diff --git a/README.md b/README.md
index 2d885a3779..31fc280e04 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,6 @@ torchtune currently supports the following models.
| [Code-Llama2](https://ai.meta.com/blog/code-llama-large-language-model-coding/) | 7B, 13B, 70B [[models](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] |
| [Mistral](https://huggingface.co/mistralai) | 7B [[models](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] |
| [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B, 7B [[models](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] |
-| [Gemma2](https://huggingface.co/docs/transformers/main/en/model_doc/gemma2) | 2B, 9B, 27B [[models](torchtune/models/gemma2/_model_builders.py), [configs](recipes/configs/gemma2/)] |
| [Microsoft Phi3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) | Mini [[models](torchtune/models/phi3/), [configs](recipes/configs/phi3/)]
| [Qwen2](https://qwenlm.github.io/blog/qwen2/) | 0.5B, 1.5B, 7B [[models](torchtune/models/qwen2/), [configs](recipes/configs/qwen2/)]
diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst
index b2d74022b1..36175bb392 100644
--- a/docs/source/api_ref_models.rst
+++ b/docs/source/api_ref_models.rst
@@ -361,37 +361,6 @@ To download the Gemma 7B model:
gemma.gemma_tokenizer
-gemma2 :
---------
-
-Models of size 2B, 9B, 27B from the `Gemma family `_.
-
-Important: You need to request access on `Hugging Face `__ to use this model.
-
-To download the Gemma2 2B, 9B, 27B models :
-
-.. code-block:: bash
-
- tune download google/gemma-2-<MODEL_SIZE>b --ignore-patterns "gemma-2-<MODEL_SIZE>b.gguf" --hf-token <HF_TOKEN>
-
-
-.. autosummary::
- :toctree: generated/
- :nosignatures:
-
- gemma2.gemma2
- gemma2.lora_gemma2
- gemma2.gemma2_2b
- gemma2.lora_gemma2_2b
- gemma2.qlora_gemma2_2b
- gemma2.gemma2_9b
- gemma2.lora_gemma2_9b
- gemma2.qlora_gemma2_9b
- gemma2.gemma2_27b
- gemma2.lora_gemma2_27b
- gemma2.qlora_gemma2_27b
- gemma.gemma_tokenizer
-
clip
----
diff --git a/docs/source/tutorials/memory_optimizations.rst b/docs/source/tutorials/memory_optimizations.rst
index a0f6d16c91..aa75024e6a 100644
--- a/docs/source/tutorials/memory_optimizations.rst
+++ b/docs/source/tutorials/memory_optimizations.rst
@@ -167,7 +167,7 @@ In addition to :ref:`reducing model and optimizer precision
All of our recipes support lower-precision optimizers from the `torchao `_ library.
For single device recipes, we also support `bitsandbytes `_.
-A good place to start might be the :class:`torchao.prototype.low_bit_optim.AdamW8bit` and :class:`bitsandbytes.optim.PagedAdamW8bit` optimizers.
+A good place to start might be the :class:`torchao.prototype.low_bit_optim.torchao.AdamW8bit` and :class:`bitsandbytes.optim.PagedAdamW8bit` optimizers.
Both reduce memory by quantizing the optimizer state dict. Paged optimizers will also offload to CPU if there isn't enough GPU memory available. In practice,
you can expect higher memory savings from bnb's PagedAdamW8bit but higher training speed from torchao's AdamW8bit.
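As a quick illustration of the optimizers this paragraph refers to, here is a minimal sketch of dropping a low-precision optimizer into a plain PyTorch step. It assumes a CUDA device and that torchao's prototype low_bit_optim module is importable at the path named above (the module has moved between torchao releases); the model is a toy stand-in, not a torchtune recipe.

import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # path as referenced above; may vary by torchao version

# Toy model; in a real recipe this would be the fine-tuned LLM.
model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Optimizer state (exp_avg, exp_avg_sq) is stored quantized to 8 bits,
# cutting optimizer-state memory substantially versus fp32 AdamW state.
optimizer = AdamW8bit(model.parameters(), lr=2e-5)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
optimizer.step()
optimizer.zero_grad()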
@@ -180,7 +180,7 @@ a low precision optimizer using the :ref:`cli_label`:
.. code-block:: bash
  tune run <RECIPE> --config <CONFIG> \
- optimizer=torchao.prototype.low_bit_optim.AdamW8bit
+ optimizer=torchao.prototype.low_bit_optim.torchao.AdamW8bit
.. code-block:: bash
diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml
deleted file mode 100644
index ddc89b38b2..0000000000
--- a/recipes/configs/gemma2/27B_full.yaml
+++ /dev/null
@@ -1,74 +0,0 @@
-# Config for multi-device full finetuning in full_finetune_distributed.py
-# using a gemma2 27B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on 4 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only when the model is being fine-tuned on 2+ GPUs.
-
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-27b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.gemma2_27b
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-27b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00024"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b
- model_type: GEMMA2
-resume_from_checkpoint: False
-
-# Fine-tuning arguments
-batch_size: 1
-epochs: 1
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-finetune
-log_every_n_steps: 1
-log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml
deleted file mode 100644
index a138441199..0000000000
--- a/recipes/configs/gemma2/27B_lora.yaml
+++ /dev/null
@@ -1,86 +0,0 @@
-# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
-# using a gemma2 27B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on 4 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only when the model is being fine-tuned on 2+ GPUs.
-
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-27b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.lora_gemma2_27b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-27b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00024"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b/
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 4
-epochs: 3
-max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml
deleted file mode 100644
index 577b0715c5..0000000000
--- a/recipes/configs/gemma2/27B_lora_single_device.yaml
+++ /dev/null
@@ -1,112 +0,0 @@
-# Config for single device LoRA finetuning in lora_finetune_single_device.py
-# using a gemma2 27B model
-#
-# This config assumes that you've run the following command before launching
-# this run (torchtune does not use gguf so you can ignore it to save time and space):
-# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on a single device, run the following command from root:
-# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only for training on single device.
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-27b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.lora_gemma2_27b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 8
- lora_alpha: 16
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-27b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00024"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b/
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 5e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 2
-epochs: 1
-max_steps_per_epoch: null
-gradient_accumulation_steps: 8
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
-
-# Show case the usage of pytorch profiler
-# Set enabled to False as it's only needed for debugging training
-profiler:
- _component_: torchtune.training.setup_torch_profiler
- enabled: False
-
- #Output directory of trace artifacts
- output_dir: ${output_dir}/profiling_outputs
-
- #`torch.profiler.ProfilerActivity` types to trace
- cpu: True
- cuda: True
-
- #trace options passed to `torch.profiler.profile`
- profile_memory: False
- with_stack: False
- record_shapes: True
- with_flops: False
-
- # `torch.profiler.schedule` options:
- # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
- wait_steps: 5
- warmup_steps: 5
- active_steps: 2
- num_cycles: 1
diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml
deleted file mode 100644
index 14d9b75ba7..0000000000
--- a/recipes/configs/gemma2/27B_qlora_single_device.yaml
+++ /dev/null
@@ -1,115 +0,0 @@
-# Config for single device QLoRA finetuning in lora_finetune_single_device.py
-# using a gemma2 27B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on a single device, run the following command from root:
-# tune run lora_finetune_single_device --config gemma2/27B_qlora_single_device
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run lora_finetune_single_device --config gemma2/27B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only for training on single device.
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-27b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.qlora_gemma2_27b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-27b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00024"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b/
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 4
-epochs: 3
-max_steps_per_epoch: null
-gradient_accumulation_steps: 4
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
-
-# Show case the usage of pytorch profiler
-# Set enabled to False as it's only needed for debugging training
-profiler:
- _component_: torchtune.training.setup_torch_profiler
- enabled: False
-
- #Output directory of trace artifacts
- output_dir: ${output_dir}/profiling_outputs
-
- #`torch.profiler.ProfilerActivity` types to trace
- cpu: True
- cuda: True
-
- #trace options passed to `torch.profiler.profile`
- profile_memory: False
- with_stack: False
- record_shapes: True
- with_flops: False
-
- # `torch.profiler.schedule` options:
- # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
- wait_steps: 5
- warmup_steps: 5
- active_steps: 2
- num_cycles: 1
-
-# For colab use True
-low_cpu_ram: False
diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml
deleted file mode 100644
index e302dd759d..0000000000
--- a/recipes/configs/gemma2/2B_full.yaml
+++ /dev/null
@@ -1,76 +0,0 @@
-# Config for multi-device full finetuning in full_finetune_distributed.py
-# using a gemma2 2B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on 4 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only when the model is being fine-tuned on 2+ GPUs.
-
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-2b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.gemma2_2b
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-2b/
- checkpoint_files: [
- model-00001-of-00003.safetensors,
- model-00002-of-00003.safetensors,
- model-00003-of-00003.safetensors,
- ]
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
- model_type: GEMMA2
-resume_from_checkpoint: False
-
-# Fine-tuning arguments
-batch_size: 2
-epochs: 3
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-finetune
-log_every_n_steps: 1
-log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml
deleted file mode 100644
index 9a439ee0a3..0000000000
--- a/recipes/configs/gemma2/2B_lora.yaml
+++ /dev/null
@@ -1,88 +0,0 @@
-# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
-# using a gemma2 2B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on 4 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/2B_lora
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/2B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only when the model is being fine-tuned on 2+ GPUs.
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-2b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.lora_gemma2_2b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-2b/
- checkpoint_files: [
- model-00001-of-00003.safetensors,
- model-00002-of-00003.safetensors,
- model-00003-of-00003.safetensors,
- ]
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
- model_type: GEMMA2
-resume_from_checkpoint: False
-
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 4
-epochs: 3
-max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml
deleted file mode 100644
index 1a2703fb47..0000000000
--- a/recipes/configs/gemma2/2B_lora_single_device.yaml
+++ /dev/null
@@ -1,114 +0,0 @@
-# Config for single device LoRA finetuning in lora_finetune_single_device.py
-# using a gemma2 2B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on a single device, run the following command from root:
-# tune run lora_finetune_single_device --config gemma2/2B_lora_single_device
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run lora_finetune_single_device --config gemma2/2B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only for training on single device.
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-2b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.lora_gemma2_2b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-2b/
- checkpoint_files: [
- model-00001-of-00003.safetensors,
- model-00002-of-00003.safetensors,
- model-00003-of-00003.safetensors,
- ]
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 8
-epochs: 3
-max_steps_per_epoch: null
-gradient_accumulation_steps: 2
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
-
-# Show case the usage of pytorch profiler
-# Set enabled to False as it's only needed for debugging training
-profiler:
- _component_: torchtune.training.setup_torch_profiler
- enabled: False
-
- #Output directory of trace artifacts
- output_dir: ${output_dir}/profiling_outputs
-
- #`torch.profiler.ProfilerActivity` types to trace
- cpu: True
- cuda: True
-
- #trace options passed to `torch.profiler.profile`
- profile_memory: False
- with_stack: False
- record_shapes: True
- with_flops: False
-
- # `torch.profiler.schedule` options:
- # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
- wait_steps: 5
- warmup_steps: 5
- active_steps: 2
- num_cycles: 1
diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml
deleted file mode 100644
index c2525460ff..0000000000
--- a/recipes/configs/gemma2/2B_qlora_single_device.yaml
+++ /dev/null
@@ -1,114 +0,0 @@
-# Config for single device QLoRA finetuning in lora_finetune_single_device.py
-# using a gemma2 2B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on a single device, run the following command from root:
-# tune run lora_finetune_single_device --config gemma2/2B_qlora_single_device
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run lora_finetune_single_device --config gemma2/2B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only for training on single device.
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-2b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.qlora_gemma2_2b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-2b/
- checkpoint_files: [
- model-00001-of-00003.safetensors,
- model-00002-of-00003.safetensors,
- model-00003-of-00003.safetensors,
- ]
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 4
-epochs: 3
-max_steps_per_epoch: null
-gradient_accumulation_steps: 4
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
-
-# Show case the usage of pytorch profiler
-# Set enabled to False as it's only needed for debugging training
-profiler:
- _component_: torchtune.training.setup_torch_profiler
- enabled: False
-
- #Output directory of trace artifacts
- output_dir: ${output_dir}/profiling_outputs
-
- #`torch.profiler.ProfilerActivity` types to trace
- cpu: True
- cuda: True
-
- #trace options passed to `torch.profiler.profile`
- profile_memory: False
- with_stack: False
- record_shapes: True
- with_flops: False
-
- # `torch.profiler.schedule` options:
- # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
- wait_steps: 5
- warmup_steps: 5
- active_steps: 2
- num_cycles: 1
diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml
deleted file mode 100644
index 0fc7e6e4e4..0000000000
--- a/recipes/configs/gemma2/9B_full.yaml
+++ /dev/null
@@ -1,74 +0,0 @@
-# Config for multi-device full finetuning in full_finetune_distributed.py
-# using a gemma2 9B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on 4 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/9B_full
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/9B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only when the model is being fine-tuned on 2+ GPUs.
-
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-9b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.gemma2_9b
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-9b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00008"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b
- model_type: GEMMA2
-resume_from_checkpoint: False
-
-# Fine-tuning arguments
-batch_size: 1
-epochs: 1
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-finetune
-log_every_n_steps: 1
-log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml
deleted file mode 100644
index 960e4fa881..0000000000
--- a/recipes/configs/gemma2/9B_lora.yaml
+++ /dev/null
@@ -1,86 +0,0 @@
-# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
-# using a gemma2 9B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on 4 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/9B_lora
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/9B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only when the model is being fine-tuned on 2+ GPUs.
-
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-9b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.lora_gemma2_9b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-9b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00008"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b/
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 4
-epochs: 3
-max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml
deleted file mode 100644
index e9d6c22a73..0000000000
--- a/recipes/configs/gemma2/9B_lora_single_device.yaml
+++ /dev/null
@@ -1,112 +0,0 @@
-# Config for single device LoRA finetuning in lora_finetune_single_device.py
-# using a gemma2 9B model
-#
-# This config assumes that you've run the following command before launching
-# this run (torchtune does not use gguf so you can ignore it to save time and space):
-# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on a single device, run the following command from root:
-# tune run lora_finetune_single_device --config gemma2/9B_lora_single_device
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run lora_finetune_single_device --config gemma2/9B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only for training on single device.
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-9b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.lora_gemma2_9b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 8
- lora_alpha: 16
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-9b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00008"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b/
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 5e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 8
-epochs: 1
-max_steps_per_epoch: null
-gradient_accumulation_steps: 2
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
-
-# Show case the usage of pytorch profiler
-# Set enabled to False as it's only needed for debugging training
-profiler:
- _component_: torchtune.training.setup_torch_profiler
- enabled: False
-
- #Output directory of trace artifacts
- output_dir: ${output_dir}/profiling_outputs
-
- #`torch.profiler.ProfilerActivity` types to trace
- cpu: True
- cuda: True
-
- #trace options passed to `torch.profiler.profile`
- profile_memory: False
- with_stack: False
- record_shapes: True
- with_flops: False
-
- # `torch.profiler.schedule` options:
- # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
- wait_steps: 5
- warmup_steps: 5
- active_steps: 2
- num_cycles: 1
diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml
deleted file mode 100644
index 8991ba9ece..0000000000
--- a/recipes/configs/gemma2/9B_qlora_single_device.yaml
+++ /dev/null
@@ -1,115 +0,0 @@
-# Config for single device QLoRA finetuning in lora_finetune_single_device.py
-# using a gemma2 9B model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token <HF_TOKEN>
-#
-# To launch on a single device, run the following command from root:
-# tune run lora_finetune_single_device --config gemma2/9B_qlora_single_device
-#
-# You can add specific overrides through the command line. For example
-# to override the checkpointer directory while launching training
-# you can run:
-# tune run lora_finetune_single_device --config gemma2/9B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
-#
-# This config works only for training on single device.
-
-# Tokenizer
-tokenizer:
- _component_: torchtune.models.gemma.gemma_tokenizer
- path: /tmp/gemma-2-9b/tokenizer.model
-
-# Dataset
-dataset:
- packed: False # Set to true for great speed ups
- _component_: torchtune.datasets.alpaca_dataset
-seed: null
-shuffle: True
-
-# Model Arguments
-model:
- _component_: torchtune.models.gemma2.qlora_gemma2_9b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
- apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
- lora_dropout: 0.0
-
-checkpointer:
- _component_: torchtune.training.FullModelHFCheckpointer
- checkpoint_dir: /tmp/gemma-2-9b/
- checkpoint_files:
- filename_format: model-{}-of-{}.safetensors
- max_filename: "00008"
- recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b/
- model_type: GEMMA2
-resume_from_checkpoint: False
-save_adapter_weights_only: False
-
-optimizer:
- _component_: torch.optim.AdamW
- fused: True
- lr: 2e-5
-
-lr_scheduler:
- _component_: torchtune.modules.get_cosine_schedule_with_warmup
- num_warmup_steps: 10
-
-loss:
- _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-
-# Fine-tuning arguments
-batch_size: 4
-epochs: 3
-max_steps_per_epoch: null
-gradient_accumulation_steps: 4
-compile: False # pytorch compile, set to true for perf/memory improvement
-
-# Training env
-device: cuda
-
-# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
-
-# Reduced precision
-dtype: bf16
-
-# Logging
-metric_logger:
- _component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-lora
-log_every_n_steps: 1
-log_peak_memory_stats: True
-
-# Show case the usage of pytorch profiler
-# Set enabled to False as it's only needed for debugging training
-profiler:
- _component_: torchtune.training.setup_torch_profiler
- enabled: False
-
- #Output directory of trace artifacts
- output_dir: ${output_dir}/profiling_outputs
-
- #`torch.profiler.ProfilerActivity` types to trace
- cpu: True
- cuda: True
-
- #trace options passed to `torch.profiler.profile`
- profile_memory: False
- with_stack: False
- record_shapes: True
- with_flops: False
-
- # `torch.profiler.schedule` options:
- # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
- wait_steps: 5
- warmup_steps: 5
- active_steps: 2
- num_cycles: 1
-
-# For colab use True
-low_cpu_ram: False
diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml
index e404b0c4dc..0cbf6c7b7a 100644
--- a/recipes/configs/llama2/7B_qat_full.yaml
+++ b/recipes/configs/llama2/7B_qat_full.yaml
@@ -67,7 +67,7 @@ device: cuda
# Memory management
enable_activation_checkpointing: True # True reduces memory
-enable_activation_offloading: False # True reduces memory
+memory_efficient_fsdp_wrap: False
# Reduced precision
dtype: bf16
diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml
index 2b08cbb10f..ce409d1bbb 100644
--- a/recipes/configs/llama3/8B_qat_full.yaml
+++ b/recipes/configs/llama3/8B_qat_full.yaml
@@ -44,6 +44,8 @@ resume_from_checkpoint: False
# Fine-tuning arguments
batch_size: 2
epochs: 3
+compile: False # pytorch compile, set to true for better perf/memory
+optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# QAT arguments
quantizer:
@@ -58,16 +60,13 @@ loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
-optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
device: cuda
# Memory management
enable_activation_checkpointing: True # True reduces memory
-enable_activation_offloading: False # True reduces memory
-custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+memory_efficient_fsdp_wrap: True
# Reduced precision
dtype: bf16
@@ -76,7 +75,7 @@ dtype: bf16
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
-output_dir: /tmp/full-llama3-finetune
+output_dir: /tmp/alpaca-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index fcdb3e4ea5..6d66431357 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -632,6 +632,7 @@ def save_checkpoint(self, epoch: int) -> None:
def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
+
# run model
with self.activations_handling_ctx:
logits = self._model(**batch)
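The activations_handling_ctx used above is torchtune's activation-offloading context manager. As a hedged illustration of the underlying mechanism only (not torchtune's implementation), PyTorch's torch.autograd.graph.save_on_cpu hook does the same conceptual job: tensors saved for backward are parked in host memory during the forward pass and fetched back during backward.

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)
x = torch.randn(4, 1024, requires_grad=True)

# Activations saved for backward are moved to the CPU in forward and copied
# back on demand in backward, trading runtime for accelerator memory.
# (pin_memory=True can speed up the copies when training on CUDA.)
with torch.autograd.graph.save_on_cpu():
    out = model(x)
out.sum().backward()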
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index 1aa622ba63..039540eff5 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+import os
import sys
import time
@@ -20,13 +21,11 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
-from torchtune.config._utils import _get_component_from_path
-from torchtune.data import padded_collate_packed
+from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
-from torchtune.training.lr_schedulers import get_lr
from tqdm import tqdm
@@ -51,7 +50,7 @@ class QATRecipeDistributed(FTRecipeInterface):
to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.
- FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
- is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+ is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.
@@ -63,18 +62,6 @@ class QATRecipeDistributed(FTRecipeInterface):
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.
- - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
- flag. Activation offloading is a technique similar to activations checkpointing that helps
- reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
- checkpointing drops the activation in the forward to recompute it later in the backward,
- activations offloading will drop the activation in the forward to the CPU and bring it
- back during the backward pass. As always, there is a tradeoff--these savings in memory can
- come at the cost of training performance and CPU resources. To recover some runtime cost,
- we've added an option to enable offloading on a different stream to permit overlapping with
- the computation. This option is currently only available on PyTorch 2.5 or later and will
- be enabled by default if an acceptable torch version is found. Activation offloading can be
- used in conjunction with activation checkpointing.
-
- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -106,10 +93,6 @@ class QATRecipeDistributed(FTRecipeInterface):
- Logging. Terminal, Disk, WandB and TensorBoard are all supported.
- - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
- ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
- ``clip_grad_norm='inf'``.
-
For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.
@@ -119,9 +102,6 @@ class QATRecipeDistributed(FTRecipeInterface):
Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
- RuntimeError: If ``left_pad_sequence`` is set as the data collator.
- RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
- RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
"""
def __init__(self, cfg: DictConfig) -> None:
@@ -161,50 +141,12 @@ def __init__(self, cfg: DictConfig) -> None:
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
- self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
- self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+ self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
+ cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
+ ]
self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
self._quantizer_mode = None
- # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
- if self._optimizer_in_bwd:
- if self._clip_grad_norm is not None:
- raise RuntimeError(
- "Gradient clipping is not supported with optimizer in bwd."
- "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
- )
- if self._gradient_accumulation_steps > 1:
- raise RuntimeError(
- "Gradient accumulation is not supported with optimizer in bwd."
- "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
- )
-
- # activation checkpointing/offloading
- self._enable_activation_checkpointing = cfg.get(
- "enable_activation_checkpointing", False
- )
- self._enable_activation_offloading = cfg.get(
- "enable_activation_offloading", False
- )
- if self._enable_activation_offloading:
- if self._device.type != "cuda":
- raise RuntimeError(
- "enable_activation_offloading should only be True when training on CUDA"
- )
- if not self._enable_activation_checkpointing:
- raise RuntimeError(
- "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
- )
- elif (
- self._enable_activation_checkpointing
- and cfg.checkpointer.model_type != "LLAMA3_VISION"
- ):
- utils.log_rank_zero(
- log,
- "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
- "Enabling activation offloading should reduce memory further.",
- )
-
# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
self.seed = training.set_seed(seed=cfg.seed)
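The fsdp_sharding_strategy lookup added above resolves a config string to an FSDP enum member. A small sketch of that mapping, assuming the FSDP1-style ShardingStrategy enum used in the hunk:

from torch.distributed.fsdp import ShardingStrategy

def resolve_sharding_strategy(name: str = "FULL_SHARD") -> ShardingStrategy:
    # Mirrors cfg.get("fsdp_sharding_strategy", "FULL_SHARD") followed by an
    # enum lookup by member name.
    return ShardingStrategy[name]

# FULL_SHARD reshards parameters after every forward (lowest memory);
# SHARD_GRAD_OP keeps the gathered parameters until backward finishes (faster).
assert resolve_sharding_strategy() is ShardingStrategy.FULL_SHARD
assert resolve_sharding_strategy("SHARD_GRAD_OP") is ShardingStrategy.SHARD_GRAD_OP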
@@ -281,11 +223,10 @@ def setup(self, cfg: DictConfig) -> None:
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
- self._compile = cfg.get("compile", False)
+ self._model_compile = cfg.get("compile", False)
self._model = self._setup_model(
cfg_model=cfg.model,
- enable_activation_checkpointing=self._enable_activation_checkpointing,
- enable_activation_offloading=self._enable_activation_offloading,
+ enable_activation_checkpointing=cfg.enable_activation_checkpointing,
custom_sharded_layers=cfg.get("custom_sharded_layers", None),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
@@ -298,7 +239,6 @@ def setup(self, cfg: DictConfig) -> None:
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
- optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
@@ -308,25 +248,30 @@ def setup(self, cfg: DictConfig) -> None:
# initialize loss
self._loss_fn = config.instantiate(cfg.loss)
-
- if self._compile:
- training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
-
+ backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-
- if self._is_rank_zero:
- log.info("Loss is initialized.")
+ if self._model_compile:
+ log.info("Compiling loss with torch.compile...")
+ # For CEWithChunkedOutputLoss, if we compile the entire class
+ # we lose the benefits from the chunked loss.
+ # Therefore, we only compile the cross entropy function + upcasting
+ self._loss_fn.compute_cross_entropy = torch.compile(
+ self._loss_fn.compute_cross_entropy, backend=backend
+ )
+ else:
+ if self._model_compile:
+ log.info("Compiling loss with torch.compile...")
+ self._loss_fn = torch.compile(self._loss_fn, backend=backend)
+ log.info("Loss is initialized.")
# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after both of these are initialized
- collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
self._sampler, self._dataloader = self._setup_data(
cfg_dataset=cfg.dataset,
shuffle=cfg.shuffle,
batch_size=cfg.batch_size,
- collate_fn=collate_name,
)
# Finally update the recipe state which can only be correctly set after all of the
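The compile branch above wraps only compute_cross_entropy rather than the whole loss module, so the chunked outer loop stays eager. A self-contained sketch of that pattern with a toy chunked loss (a stand-in, not torchtune's CEWithChunkedOutputLoss):

import torch
import torch.nn.functional as F

class ToyChunkedCE(torch.nn.Module):
    """Stand-in for a chunked cross-entropy loss over flattened tokens."""

    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index

    def compute_cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Upcast + per-chunk summed CE: the hot kernel worth compiling.
        return F.cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits_chunks = logits.chunk(self.num_output_chunks, dim=0)
        labels_chunks = labels.chunk(self.num_output_chunks, dim=0)
        total = sum(
            self.compute_cross_entropy(lc, tc)
            for lc, tc in zip(logits_chunks, labels_chunks)
        )
        return total / (labels != self.ignore_index).sum()

loss_fn = ToyChunkedCE()
# Compile only the inner function, as the recipe code above does, so the
# chunking (which exists to cap peak memory) is not traced into one big graph.
loss_fn.compute_cross_entropy = torch.compile(loss_fn.compute_cross_entropy)

logits = torch.randn(64, 32000)          # flattened (tokens, vocab)
labels = torch.randint(0, 32000, (64,))  # flattened token targets
loss = loss_fn(logits, labels)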
@@ -426,7 +371,6 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
- enable_activation_offloading: bool,
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
model_state_dict: Dict[str, Any],
@@ -452,9 +396,6 @@ def _setup_model(
with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)
- if self._compile:
- training.compile_model(model, verbose=self._is_rank_zero)
-
# We currently have two versions of activation checkpointing in this recipe
# for testing and BC purposes. ``enable_activation_checkpointing`` controls
# the older version of AC and this behavior is unchanged
@@ -510,17 +451,7 @@ def _setup_model(
# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
- model,
- model_state_dict,
- self._device,
- self._is_rank_zero,
- strict=True,
- cpu_offload=fsdp_cpu_offload,
- )
-
- # activation offloading
- self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
- model, enable_activation_offloading
+ model, model_state_dict, self._device, self._is_rank_zero, strict=True
)
# Ensure no params and buffers are on meta device
@@ -539,64 +470,25 @@ def _setup_model(
return model
def _setup_optimizer(
- self,
- cfg_optimizer: DictConfig,
- optimizer_in_bwd: bool = False,
- opt_state_dict: Optional[Dict[str, Any]] = None,
- ) -> Optional[Optimizer]:
- if optimizer_in_bwd:
- # Maintain a dict of optims for every parameter.
- optim_dict = {
- param: config.instantiate(cfg_optimizer, [param])
- for param in self._model.parameters()
- }
-
- # Register optimizer step hooks on the model to run optimizer in backward.
- training.register_optim_in_bwd_hooks(
- model=self._model, optim_dict=optim_dict
+ self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
+ ) -> Optimizer:
+ optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+ if opt_state_dict:
+ training.load_from_full_optimizer_state_dict(
+ optimizer,
+ opt_state_dict,
+ self._device,
)
- # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
- self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
- model=self._model, optim_dict=optim_dict
- )
- # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
- # backward run, these need to have been saved with the same setting. Cannot restore from runs that
- # did not use optimizer in backward.
- if opt_state_dict is not None:
- for param in opt_state_dict.keys():
- try:
- training.load_from_full_optimizer_state_dict(
- self._optim_ckpt_wrapper.state_dict()[param],
- opt_state_dict[param],
- self._device,
- )
- except BaseException as e:
- raise RuntimeError(
- "Failed loading in-backward optimizer checkpoints."
- "Please make sure run being restored from was using in-backward optimizer."
- ) from e
- if self._is_rank_zero:
- log.info("In-backward optimizers are set up.")
- return None
- else:
- optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
- if opt_state_dict:
- training.load_from_full_optimizer_state_dict(
- optimizer,
- opt_state_dict,
- self._device,
- )
- if self._is_rank_zero:
- log.info("Optimizer is initialized.")
- return optimizer
+ if self._is_rank_zero:
+ log.info("Optimizer is initialized.")
+ return optimizer
def _setup_data(
self,
cfg_dataset: DictConfig,
shuffle: bool,
batch_size: int,
- collate_fn: str,
) -> Tuple[DistributedSampler, DataLoader]:
"""
All data related setup happens here. Currently this recipe only supports the
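For contrast with the single-optimizer path kept above, a minimal sketch of the "optimizer in backward" idea that this hunk removes: one optimizer per parameter, stepped from a post-accumulate-grad hook so no separate optimizer.step() pass (and no retained gradients) is needed. This uses plain PyTorch (Tensor.register_post_accumulate_grad_hook, available in recent releases), not torchtune's helpers.

import torch

model = torch.nn.Linear(128, 128)

# One small optimizer per parameter.
optim_dict = {p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()}

def step_in_backward(param: torch.Tensor) -> None:
    # Runs as soon as this parameter's gradient has been accumulated.
    optim_dict[param].step()
    optim_dict[param].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_in_backward)

loss = model(torch.randn(4, 128)).sum()
loss.backward()  # parameters are updated during backward; no optimizer.step() call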
@@ -607,20 +499,15 @@ def _setup_data(
if isinstance(cfg_dataset, ListConfig):
datasets = [
- config.instantiate(single_cfg_dataset, self._tokenizer)
+ config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
- ds = config.instantiate(cfg_dataset, self._tokenizer)
+ ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
packed = cfg_dataset.get("packed", False)
- # Instantiate collate_fn
- if "left_pad_sequence" in collate_fn:
- raise RuntimeError("left_pad_sequence collator is only for inference.")
- collate_fn = _get_component_from_path(collate_fn)
-
sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
)
@@ -632,12 +519,14 @@ def _setup_data(
drop_last=True,
collate_fn=(
partial(
- collate_fn,
+ padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
- else padded_collate_packed
+ else partial(
+ padded_collate_packed,
+ )
),
)
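A small usage sketch of the collate wiring shown above. It assumes torchtune's public torchtune.data.padded_collate_sft with the keyword arguments used in the hunk; the two toy samples are made up.

from functools import partial
from torch.utils.data import DataLoader
from torchtune.data import padded_collate_sft

# Tokenized samples of unequal length, as an SFT dataset would yield them.
samples = [
    {"tokens": [1, 5, 9, 2], "labels": [-100, 5, 9, 2]},
    {"tokens": [1, 7, 2], "labels": [-100, 7, 2]},
]

loader = DataLoader(
    samples,
    batch_size=2,
    collate_fn=partial(padded_collate_sft, padding_idx=0, ignore_idx=-100),
)

# "tokens" is right-padded with padding_idx and "labels" with ignore_idx,
# so padded positions are skipped by the loss.
batch = next(iter(loader))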
@@ -664,54 +553,25 @@ def save_checkpoint(
checkpoint_dict = {}
intermediate_checkpoint = epoch + 1 < self.total_epochs
-
- if self._is_rank_zero:
- log.info(
- "Saving checkpoint. This may take some time. Retrieving full model state dict..."
- )
- start = time.perf_counter()
-
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
- device=self._device,
)
- if self._is_rank_zero:
- log.info(
- f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
- )
-
if intermediate_checkpoint:
- start = time.perf_counter()
- if self._is_rank_zero:
- log.info("Getting optimizer state dict...")
- if not self._optimizer_in_bwd:
- opt_state_dict = training.get_full_optimizer_state_dict(
- self._optimizer,
- self._is_rank_zero,
- device=self._device,
- )
- else:
- opt_state_dict = {}
- for param, opt in self._optim_ckpt_wrapper.optim_map.items():
- opt_state_dict[param] = training.get_full_optimizer_state_dict(
- opt, self._is_rank_zero, device=self._device
- )
- if self._is_rank_zero:
- log.info(
- f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
- )
+ opt_state_dict = training.get_full_optimizer_state_dict(
+ self._optimizer,
+ self._is_rank_zero,
+ )
else:
opt_state_dict = None
# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
-
if self._is_rank_zero:
- start = time.perf_counter()
+
checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
# if training is in-progress, checkpoint the optimizer state and recipe state
@@ -732,9 +592,6 @@ def save_checkpoint(
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
)
- log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
-
- torch.distributed.barrier()
def train(self) -> None:
"""
@@ -742,15 +599,10 @@ def train(self) -> None:
"""
# clean up before training begins
training.cleanup_before_training()
-
world_size, rank = training.get_world_size_and_rank()
# zero out the gradients before starting training
- if not self._optimizer_in_bwd:
- self._optimizer.zero_grad()
- else:
- for opt in self._optim_ckpt_wrapper.optim_map.values():
- opt.zero_grad()
+ self._optimizer.zero_grad()
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
@@ -760,6 +612,7 @@ def train(self) -> None:
self._profiler.start()
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
+
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)
@@ -782,6 +635,13 @@ def train(self) -> None:
):
torch.cuda.memory._record_memory_history()
+ # Both are shape [b, s]
+ tokens, labels = batch["tokens"], batch["labels"]
+ # Get the attention mask and position ids from the dataset if they
+ # exist. Currently, only sample packing in PackedDataset returns these
+ mask = batch.get("mask", None) # shape [b, s, s]
+ input_pos = batch.get("input_pos", None) # shape [b, s]
+
# Optionally wait N steps before enabling fake quant
if self._fake_quant_after_n_steps is not None:
if self.global_step == 0:
@@ -803,20 +663,20 @@ def train(self) -> None:
)
self._model.apply(enable_fq)
- utils.batch_to_device(batch, self._device)
+ tokens = tokens.to(self._device)
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
+
+ utils.batch_to_device(batch, self._device)
+
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens
-
- # Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
- with self.activations_handling_ctx:
- logits = self._model(**batch)
+ logits = self._model(**batch)
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -829,40 +689,25 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))
# Compute loss
- # Loss is normalized by default so we multiply by the number of tokens
- # This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = self._loss_fn(logits, labels) * current_num_tokens
# free logits otherwise it peaks backward memory
del logits
running_loss += current_loss
-
- # For optimizer in backward, we need to normalize before calling backward
- # This case and gradient accumulation are mutually exclusive
- if self._optimizer_in_bwd:
- torch.distributed.all_reduce(num_tokens)
- torch.distributed.all_reduce(running_loss)
- current_loss = current_loss / num_tokens
-
current_loss.backward()
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
- if not self._optimizer_in_bwd:
- # Get total number of tokens across all ranks to normalize gradients
- torch.distributed.all_reduce(num_tokens)
- # This will ensure that the logged loss matches what we're optimizing
- torch.distributed.all_reduce(running_loss)
- # Manually scale the gradients from unnormalized loss by total # of tokens
- training.scale_grads(self._model, 1 / num_tokens)
- if self._clip_grad_norm is not None:
- grad_norm = torch.nn.utils.clip_grad_norm_(
- self._model.parameters(),
- max_norm=float(self._clip_grad_norm),
- )
- self._optimizer.step()
- self._optimizer.zero_grad(set_to_none=True)
+ # Get total number of tokens across all ranks to normalize gradients
+ torch.distributed.all_reduce(num_tokens)
+ # This will ensure that the logged loss matches what we're optimizing
+ torch.distributed.all_reduce(running_loss)
+ # Manually scale the gradients from unnormalized loss by total # of tokens
+ training.scale_grads(self._model, 1 / num_tokens)
+
+ self._optimizer.step()
+ self._optimizer.zero_grad(set_to_none=True)
# Update the number of steps when the weights are updated
self.global_step += 1
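
The restored loop accumulates an unnormalized, token-weighted loss and only rescales the gradients by the total token count at the optimizer step. A single-process sketch of that normalization scheme (torchtune additionally all-reduces num_tokens and running_loss across ranks; the model and data below are toy placeholders):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)
    grad_accum_steps = 2

    num_tokens = 0
    for step in range(grad_accum_steps):
        logits = model(torch.randn(3, 8))  # [b, vocab]
        labels = torch.randint(0, 4, (3,))
        current_num_tokens = (labels != -100).sum()
        # Undo the mean reduction so micro-batches can be re-normalized by the
        # *total* number of unmasked tokens later.
        loss = loss_fn(logits, labels) * current_num_tokens
        loss.backward()
        num_tokens += current_num_tokens

    # Manually scale the accumulated grads by 1 / total tokens, then step.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(num_tokens)
    opt.step()
    opt.zero_grad(set_to_none=True)
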
@@ -881,22 +726,15 @@ def train(self) -> None:
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
- "lr": get_lr(
- (
- self._optimizer
- if not self._optimizer_in_bwd
- else self._optim_ckpt_wrapper
- ),
+ "lr": self._optimizer.param_groups[0]["lr"],
+ "tokens_per_second_per_gpu": (
+ num_tokens / time_per_step * world_size
),
- "tokens_per_second_per_gpu": num_tokens
- / (time_per_step * world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
training.get_memory_stats(device=self._device)
)
- if self._clip_grad_norm is not None:
- log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
@@ -946,7 +784,7 @@ def recipe_main(cfg: DictConfig) -> None:
"""
if not training.is_distributed():
raise RuntimeError(
- "Distributed finetune recipe should be run via a distributed launcher."
+            "Distributed QAT recipe should be run via a distributed launcher. "
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
diff --git a/tests/cache_artifacts.sh b/tests/cache_artifacts.sh
index 81b50b5889..230d26dba0 100755
--- a/tests/cache_artifacts.sh
+++ b/tests/cache_artifacts.sh
@@ -18,6 +18,9 @@ SMALL_MODEL_URLS=(
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-03082024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-tune-llama3-05052024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-reward-07122024.pt"
+ "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-meta-vision-10172024.pt"
+ "https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-vision-10172024.pt"
)
FULL_MODEL_URL=("s3://pytorch-multimodal/llama2-7b-torchtune.pt")
TOKENIZER_URLS=(
diff --git a/tests/recipes/test_eleuther_eval.py b/tests/recipes/test_eleuther_eval.py
index 1c3a7bb65f..ec71fe3312 100644
--- a/tests/recipes/test_eleuther_eval.py
+++ b/tests/recipes/test_eleuther_eval.py
@@ -13,7 +13,12 @@
import pytest
from tests.common import TUNE_PATH
-from tests.recipes.utils import llama2_test_config, write_hf_ckpt_config
+from tests.recipes.utils import (
+ llama2_test_config,
+ llama3_2_vision_test_config,
+ write_hf_ckpt_config,
+ write_hf_vision_ckpt_config,
+)
from tests.test_utils import CKPT_MODEL_PATHS
@@ -26,6 +31,30 @@ class TestEleutherEval:
("truthfulqa_mc2", 0.4, 4),
],
)
+ @pytest.fixture
+ def hide_correct_version_number(self, monkeypatch):
+ import importlib.metadata
+
+ import_orig = importlib.metadata.version
+
+ def mocked_import(name, *args, **kwargs):
+ if name == "lm-eval":
+ return "0.4.4" # Hardcode wrong version number
+ return import_orig(name, *args, **kwargs)
+
+ monkeypatch.setattr(importlib.metadata, "version", mocked_import)
+
+ @pytest.fixture
+ def expected_vision_acc(self):
+ return {
+ "Science": 0.2,
+ "Biology": 0.4,
+ "Chemistry": 0.0,
+ "Geography": 0.2,
+ "Math": 0.4,
+ "Physics": 0.0,
+ }
+
@pytest.mark.integration_test
def test_torchtune_checkpoint_eval_results(
self, caplog, monkeypatch, tmpdir, eval_name, expected_acc, bsz
@@ -74,22 +103,9 @@ def test_torchtune_checkpoint_eval_results(
acc_result = float(search_results.group(1))
assert math.isclose(acc_result, expected_acc, abs_tol=0.05)
- @pytest.fixture
- def hide_correct_version_number(self, monkeypatch):
- import importlib.metadata
-
- import_orig = importlib.metadata.version
-
- def mocked_import(name, *args, **kwargs):
- if name == "lm-eval":
- return "0.4.4" # Hardcode wrong version number
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(importlib.metadata, "version", mocked_import)
-
@pytest.mark.integration_test
@pytest.mark.usefixtures("hide_correct_version_number")
- def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
+ def test_eval_recipe_errors_without_lm_eval(self, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
@@ -123,7 +139,7 @@ def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
@pytest.mark.integration_test
def test_eval_recipe_errors_with_quantization_hf_checkpointer(
- self, capsys, monkeypatch, tmpdir
+ self, monkeypatch, tmpdir
):
ckpt = "llama2_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
@@ -162,7 +178,7 @@ def test_eval_recipe_errors_with_quantization_hf_checkpointer(
runpy.run_path(TUNE_PATH, run_name="__main__")
@pytest.mark.integration_test
- def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir):
+ def test_eval_recipe_errors_with_qat_quantizer(self, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
@@ -194,3 +210,84 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
match="QAT quantizers should only be used during quantization aware training",
):
runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ @pytest.mark.integration_test
+ def test_meta_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
+ ckpt = "llama3_2_vision_meta"
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ ckpt_dir = ckpt_path.parent
+
+ cmd = f"""
+ tune run eleuther_eval \
+ --config llama3_2_vision/11B_evaluation \
+ output_dir={tmpdir} \
+ checkpointer=torchtune.training.FullModelMetaCheckpointer \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+ checkpointer.checkpoint_files=[{ckpt_path}] \
+ ~checkpointer.checkpoint_files.filename_format \
+ ~checkpointer.checkpoint_files.max_filename \
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type=LLAMA3_VISION \
+ tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+ tokenizer.prompt_template=null \
+ limit=5 \
+ dtype=bf16 \
+ device=cpu \
+ """.split()
+
+ model_config = llama3_2_vision_test_config()
+ cmd = cmd + model_config
+
+ monkeypatch.setattr(sys, "argv", cmd)
+ with pytest.raises(SystemExit, match=""):
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ out = caplog.text
+
+ pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"
+
+ matches = re.findall(pattern, out, re.MULTILINE)
+ for task_name, _, accuracy in matches:
+ assert math.isclose(float(accuracy), expected_vision_acc[task_name])
+
+ @pytest.mark.integration_test
+ def test_hf_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
+ ckpt = "llama3_2_vision_hf"
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ ckpt_dir = ckpt_path.parent
+
+ # Config file needed for model conversion.
+ write_hf_vision_ckpt_config(ckpt_dir)
+
+ cmd = f"""
+ tune run eleuther_eval \
+ --config llama3_2_vision/11B_evaluation \
+ output_dir={tmpdir} \
+ checkpointer=torchtune.training.FullModelHFCheckpointer \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+        checkpointer.checkpoint_files=[{ckpt_path}] \
+ ~checkpointer.checkpoint_files.filename_format \
+ ~checkpointer.checkpoint_files.max_filename \
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type=LLAMA3_VISION \
+ tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+ tokenizer.prompt_template=null \
+ limit=5 \
+ dtype=bf16 \
+ device=cpu \
+ """.split()
+
+ model_config = llama3_2_vision_test_config()
+ cmd = cmd + model_config
+
+ monkeypatch.setattr(sys, "argv", cmd)
+ with pytest.raises(SystemExit, match=""):
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ out = caplog.text
+
+ pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"
+
+ matches = re.findall(pattern, out, re.MULTILINE)
+ for task_name, _, accuracy in matches:
+ assert math.isclose(float(accuracy), expected_vision_acc[task_name])
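
The new vision tests recover per-task accuracy by running the regex above over lm-eval's printed results table. A toy illustration against a made-up row in roughly that shape (actual lm-eval output formatting may differ slightly):

    import re

    sample = "| - Biology    |      0 | none   |      0 |acc    | ↑  | 0.4 |"
    pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"
    print(re.findall(pattern, sample, re.MULTILINE))
    # [('Biology', '0', '0.4')]
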
diff --git a/tests/recipes/utils.py b/tests/recipes/utils.py
index baa8ad23a9..7c35eedc2a 100644
--- a/tests/recipes/utils.py
+++ b/tests/recipes/utils.py
@@ -128,6 +128,58 @@ def llama3_test_config() -> List[str]:
]
+def llama3_2_vision_test_config() -> List[str]:
+ return [
+ "model=tests.recipes.utils.dummy_vision_model",
+ "tokenizer._component_=torchtune.models.llama3_2_vision._transform.Llama3VisionTransform",
+ "tokenizer.patch_size=9",
+ "tokenizer.max_num_tiles=2",
+ "tokenizer.tile_size=18",
+ "tokenizer.max_seq_len=4096",
+ ]
+
+
+def dummy_vision_model():
+ from torchtune.models.llama3_2_vision._component_builders import (
+ llama3_2_vision_decoder,
+ llama3_2_vision_encoder,
+ )
+ from torchtune.modules.model_fusion import DeepFusionModel
+
+ vision_encoder = llama3_2_vision_encoder(
+ clip_embed_dim=128,
+ clip_num_layers=4,
+ num_heads=4,
+ tile_size=18,
+ patch_size=9,
+ max_num_tiles=2,
+ in_channels=3,
+ clip_hidden_states=[0, 1],
+ num_layers_projection=2,
+ decoder_embed_dim=128,
+ )
+ vision_decoder = llama3_2_vision_decoder(
+ vocab_size=128256,
+ num_layers=4,
+ fusion_interval=2,
+ num_special_tokens=2,
+ num_heads=8,
+ num_kv_heads=4,
+ embed_dim=128,
+ max_seq_len=4096,
+ encoder_max_seq_len=4096,
+ )
+
+ model = DeepFusionModel(
+ encoder=vision_encoder,
+ decoder=vision_decoder,
+ encoder_trainable=False,
+ decoder_trainable=False,
+ fusion_trainable=False,
+ )
+ return model
+
+
def lora_llama2_test_config(
lora_attn_modules,
apply_lora_to_mlp: bool = False,
@@ -199,6 +251,27 @@ def write_hf_ckpt_config(ckpt_dir: str):
json.dump(config, f)
+def write_hf_vision_ckpt_config(ckpt_dir: str):
+ config = {
+ "text_config": {
+ "num_attention_heads": 8,
+ "num_key_value_heads": 4,
+ "hidden_size": 128,
+ "vocab_size": 128256,
+ "cross_attention_layers": [1, 4],
+ },
+ "vision_config": {
+ "hidden_size": 128,
+ "image_size": 18,
+ "max_num_tiles": 2,
+ "supported_aspect_ratios": [[1, 1], [1, 2], [2, 1]],
+ },
+ }
+ config_file = Path.joinpath(Path(ckpt_dir), "config.json")
+ with config_file.open("w") as f:
+ json.dump(config, f)
+
+
MODEL_TEST_CONFIGS = {
"llama2": llama2_test_config(),
"llama3": llama3_test_config(),
diff --git a/tests/test_utils.py b/tests/test_utils.py
index bcb26285a1..1b820489df 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -33,6 +33,8 @@
"llama2_hf": "/tmp/test-artifacts/small-ckpt-hf-03082024.pt",
"llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt",
"llama3_tune": "/tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt",
+ "llama3_2_vision_hf": "/tmp/test-artifacts/small-ckpt-hf-vision-10172024.pt",
+ "llama3_2_vision_meta": "/tmp/test-artifacts/small-ckpt-meta-vision-10172024.pt",
"llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
}
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index c40e89184b..cdb1d45f01 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -109,9 +109,6 @@ class Recipe:
Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"),
Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"),
Config(name="gemma/7B_full", file_path="gemma/7B_full.yaml"),
- Config(name="gemma2/2B_full", file_path="gemma2/2B_full.yaml"),
- Config(name="gemma2/9B_full", file_path="gemma2/9B_full.yaml"),
- Config(name="gemma2/27B_full", file_path="gemma2/27B_full.yaml"),
Config(name="phi3/mini_full", file_path="phi3/mini_full.yaml"),
Config(name="qwen2/7B_full", file_path="qwen2/7B_full.yaml"),
Config(name="qwen2/0.5B_full", file_path="qwen2/0.5B_full.yaml"),
@@ -219,30 +216,6 @@ class Recipe:
name="gemma/7B_qlora_single_device",
file_path="gemma/7B_qlora_single_device.yaml",
),
- Config(
- name="gemma2/2B_lora_single_device",
- file_path="gemma2/2B_lora_single_device.yaml",
- ),
- Config(
- name="gemma2/2B_qlora_single_device",
- file_path="gemma2/2B_qlora_single_device.yaml",
- ),
- Config(
- name="gemma2/9B_lora_single_device",
- file_path="gemma2/9B_lora_single_device.yaml",
- ),
- Config(
- name="gemma2/9B_qlora_single_device",
- file_path="gemma2/9B_qlora_single_device.yaml",
- ),
- Config(
- name="gemma2/27B_lora_single_device",
- file_path="gemma2/27B_lora_single_device.yaml",
- ),
- Config(
- name="gemma2/27B_qlora_single_device",
- file_path="gemma2/27B_qlora_single_device.yaml",
- ),
Config(
name="phi3/mini_lora_single_device",
file_path="phi3/mini_lora_single_device.yaml",
@@ -356,9 +329,6 @@ class Recipe:
Config(name="mistral/7B_lora", file_path="mistral/7B_lora.yaml"),
Config(name="gemma/2B_lora", file_path="gemma/2B_lora.yaml"),
Config(name="gemma/7B_lora", file_path="gemma/7B_lora.yaml"),
- Config(name="gemma2/2B_lora", file_path="gemma2/2B_lora.yaml"),
- Config(name="gemma2/9B_lora", file_path="gemma2/9B_lora.yaml"),
- Config(name="gemma2/27B_lora", file_path="gemma2/27B_lora.yaml"),
Config(name="phi3/mini_lora", file_path="phi3/mini_lora.yaml"),
Config(name="qwen2/7B_lora", file_path="qwen2/7B_lora.yaml"),
Config(name="qwen2/0.5B_lora", file_path="qwen2/0.5B_lora.yaml"),
diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py
index bb4b1ff0b0..c2d60a7373 100644
--- a/torchtune/generation/_generation.py
+++ b/torchtune/generation/_generation.py
@@ -67,7 +67,7 @@ def generate_next_token(
model: TransformerDecoder,
input_pos: torch.Tensor,
x: torch.Tensor,
- q: Optional[torch.Tensor] = None,
+ q: torch.Tensor,
*,
mask: Optional[torch.Tensor] = None,
temperature: float = 1.0,
@@ -82,7 +82,7 @@ def generate_next_token(
with shape [bsz x seq_length].
x (torch.Tensor): tensor with the token IDs associated with the given prompt,
with shape [bsz x seq_length].
- q (Optional[torch.Tensor]): randomly sampled tensor for softmax sampling trick.
+ q (torch.Tensor): randomly sampled tensor for softmax sampling trick.
See https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40
mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length],
default None.
@@ -302,11 +302,9 @@ def generate(
# tensors are of identical shape to the prompt
curr_masks = masks[:, :prompt_length, :prompt_length]
- q = None
- if rng is not None:
- q = torch.empty(
- (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
- ).exponential_(1, generator=rng)
+ q = torch.empty(
+ (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
+ ).exponential_(1, generator=rng)
tokens, generated_logits = generate_next_token(
model,
input_pos=input_pos[:, :prompt_length].squeeze(),
@@ -362,11 +360,9 @@ def generate(
curr_input_pos = input_pos[:, : curr_pos + 1]
curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1]
- q = None
- if rng is not None:
- q = torch.empty(
- (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
- ).exponential_(1, generator=rng)
+ q = torch.empty(
+ (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
+ ).exponential_(1, generator=rng)
tokens, logits = custom_generate_next_token(
model,
input_pos=curr_input_pos,
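
The revert above makes the exponential noise tensor q unconditional again. For context, the sampling trick it supports draws q ~ Exponential(1) and takes argmax(probs / q), which selects index i with probability probs[i]; a toy check of that equivalence (not torchtune's generate_next_token itself):

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])

    counts = torch.zeros(3)
    for _ in range(10_000):
        q = torch.empty_like(probs).exponential_(1)
        counts[torch.argmax(probs / q)] += 1

    # Empirical frequencies approach probs, so sampling can be driven entirely
    # by the pre-drawn q (and hence by a shared RNG across ranks).
    print(counts / counts.sum())  # roughly tensor([0.1, 0.2, 0.7])
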
diff --git a/torchtune/models/gemma/__init__.py b/torchtune/models/gemma/__init__.py
index f762de86b6..48e4e84b10 100644
--- a/torchtune/models/gemma/__init__.py
+++ b/torchtune/models/gemma/__init__.py
@@ -27,4 +27,6 @@
"lora_gemma_7b",
"qlora_gemma_2b",
"qlora_gemma_7b",
+ "gemma_hf_to_tune",
+ "gemma_tune_to_hf",
]
diff --git a/torchtune/models/gemma/_component_builders.py b/torchtune/models/gemma/_component_builders.py
index ba5b666c98..e7ab9b224c 100644
--- a/torchtune/models/gemma/_component_builders.py
+++ b/torchtune/models/gemma/_component_builders.py
@@ -46,6 +46,7 @@ def gemma(
attn_dropout: float = 0.0,
norm_eps: float = 1e-6,
rope_base: int = 10_000,
+ norm_embeddings: bool = True,
) -> TransformerDecoder:
"""
Build the decoder associated with the gemma model. This includes:
@@ -71,6 +72,8 @@ def gemma(
Default: 0.0
norm_eps (float): epsilon in RMS norms Default: 1e-6
rope_base (int): base for the rotary positional embeddings. Default: 10_000
+ norm_embeddings (bool): whether to apply layer norm before the self-attention
+ and mlp layers. Default: True
Returns:
TransformerDecoder: Instantiation of gemma model.
@@ -143,6 +146,7 @@ def lora_gemma(
attn_dropout: float = 0.0,
norm_eps: float = 1e-6,
rope_base: int = 10_000,
+ norm_embeddings: bool = True,
# LoRA args
lora_rank: int,
lora_alpha: float,
@@ -173,6 +177,8 @@ def lora_gemma(
Default: 0.0
norm_eps (float): epsilon in RMS norms Default: 1e-6
rope_base (int): base for the rotary positional embeddings. Default: 10_000
+ norm_embeddings (bool): whether to apply layer norm before the self-attention
+ and mlp layers. Default: True
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
diff --git a/torchtune/models/gemma2/__init__.py b/torchtune/models/gemma2/__init__.py
deleted file mode 100644
index 9fe11db7ab..0000000000
--- a/torchtune/models/gemma2/__init__.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from ..gemma._model_builders import gemma_tokenizer
-from ..gemma._tokenizer import GemmaTokenizer # noqa
-from ._component_builders import gemma2, lora_gemma2 # noqa
-from ._model_builders import ( # noqa
- gemma2_27b,
- gemma2_2b,
- gemma2_9b,
- lora_gemma2_27b,
- lora_gemma2_2b,
- lora_gemma2_9b,
- qlora_gemma2_27b,
- qlora_gemma2_2b,
- qlora_gemma2_9b,
-)
-
-__all__ = [
- "GemmaTokenizer",
- "gemma2",
- "gemma2_2b",
- "gemma2_9b",
- "gemma2_27b",
- "gemma_tokenizer",
- "lora_gemma2",
- "lora_gemma2_2b",
- "lora_gemma2_9b",
- "lora_gemma2_27b",
- "qlora_gemma2_2b",
- "qlora_gemma2_9b",
- "qlora_gemma2_27b",
-]
diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py
deleted file mode 100644
index 0ddef36857..0000000000
--- a/torchtune/models/gemma2/_component_builders.py
+++ /dev/null
@@ -1,413 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from torch import nn
-import torch
-from typing import List
-from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks
-from typing import List, Optional
-
-from torchtune.modules import (
- FrozenNF4Linear,
- RotaryPositionalEmbeddings,
- TransformerSelfAttentionLayer,
-)
-
-from torchtune.models.gemma2._attention import Gemma2Attention
-from torchtune.models.gemma.rms_norm import GemmaRMSNorm
-from torchtune.modules import TransformerDecoder, TiedLinear
-from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
-from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
-from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp
-
-"""
-Component builders for the Gemma2 2B, 9B models and popular variants such as LoRA.
-
-torchtune provides composable building blocks. Builder functions help
-stitch these building blocks into higher-level components. This design has
-two benefits:
-- The building blocks themselves are very flexible. For example, ``MultiHeadAttention``
-can take either nn.Linear or nn.LoRALinear for ``q_proj``.
-- Builder functions expose a set of configurable params which keep the constructors of
-the building blocks simple.
-"""
-
-class TanhSoftCapping(nn.Module):
- def __init__(
- self,
- capping_value: float,
- ) -> None:
- super().__init__()
- self.capping_value = capping_value
-
- def forward(self, attn_weights):
- attn_weights = attn_weights / self.capping_value
- attn_weights = torch.tanh(attn_weights)
- attn_weights = attn_weights * self.capping_value
- return attn_weights
-
-class Gemma2FinalNorm(nn.Module):
- """
- Combines RMSNorm and SoftCapping
- """
- def __init__(
- self,
- capping_value: float,
- embed_dim: int,
- eps: float
- ) -> None:
- super().__init__()
- self.capping_value = capping_value
- self.rms_norm = GemmaRMSNorm(embed_dim, eps=eps)
- self.logit_capping = TanhSoftCapping(capping_value)
-
- def forward(self, x):
- x = self.rms_norm(x)
- x = self.logit_capping(x)
- return x
-
-
-def gemma2(
- vocab_size: int,
- num_layers: int,
- num_heads: int,
- head_dim: int,
- num_kv_heads: int,
- embed_dim: int,
- intermediate_dim: int,
- max_seq_len: int,
- attn_dropout: float = 0.0,
- norm_eps: float = 1e-6,
- rope_base: int = 10_000,
- hidden_capping_value: float = 50.,
- final_capping_value: float = 30.,
- sliding_window_size: int = 4096,
- query_pre_attn_scalar: Optional[int] = None,
-) -> TransformerDecoder:
- """
- Build the decoder associated with the gemma2 model. This includes:
- - Token embeddings
- - num_layers number of TransformerSelfAttentionLayer blocks
- - RMS Norm layer applied to the output of the transformer
- - Final projection into token space
-
-
- Args:
- vocab_size (int): number of tokens in vocabulary.
- num_layers (int): number of layers in the transformer decoder.
- num_heads (int): number of query heads. For MHA this is also the
- number of heads for key and value
- head_dim (int): dimension of head
- num_kv_heads (int): number of key and value heads.
- embed_dim (int): embedding dimension for self-attention
- intermediate_dim (int): intermediate dimension for MLP
- max_seq_len (int): maximum sequence length the model will be run with,
- attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
- Default: 0.0
- norm_eps (float): epsilon in RMS norms Default: 1e-6
- rope_base (int): base for the rotary positional embeddings. Default: 10_000
-
- Returns:
- TransformerDecoder: Instantiation of gemma model.
- """
- rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
-
- layers = torch.nn.ModuleList()
-
- for layer_idx in range(num_layers):
-
- mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
-
- self_att = Gemma2Attention(
- embed_dim=embed_dim,
- num_heads=num_heads,
- num_kv_heads=num_kv_heads,
- head_dim=head_dim,
- q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
- k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
- v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
- output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
- pos_embeddings=rope,
- kv_cache=None,
- max_seq_len=max_seq_len,
- attn_dropout=attn_dropout,
- # perform sliding window on half of the layers only
- sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None,
- softcapping=hidden_capping_value,
- query_pre_attn_scalar=query_pre_attn_scalar
- )
-
- layer = TransformerSelfAttentionLayer(
- attn=self_att,
- mlp=mlp,
- sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
- mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
- sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps),
- mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps),
- )
- layers.append(layer)
- tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
- output_proj = TiedLinear(tok_embeddings)
- model = TransformerDecoder(
- tok_embeddings=tok_embeddings,
- layers=layers,
- max_seq_len=max_seq_len,
- num_heads=num_heads,
- output=output_proj,
- head_dim=head_dim,
- norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps),
- )
- return model
-
-
-
-def lora_gemma2(
- lora_attn_modules: List[LORA_ATTN_MODULES],
- apply_lora_to_mlp: bool = False,
- *,
- # gemma args
- vocab_size: int,
- num_layers: int,
- num_heads: int,
- head_dim: int,
- num_kv_heads: int,
- embed_dim: int,
- intermediate_dim: int,
- max_seq_len: int,
- attn_dropout: float = 0.0,
- norm_eps: float = 1e-6,
- rope_base: int = 10_000,
- hidden_capping_value: float = 50.,
- final_capping_value: float = 30.,
- sliding_window_size: int = 4096,
- query_pre_attn_scalar: Optional[int] = None,
- # LoRA args
- lora_rank: int,
- lora_alpha: float,
- lora_dropout: float = 0.0,
- use_dora: bool = False,
- quantize_base: bool = False,
-) -> TransformerDecoder:
- """
- Return a version of Gemma with LoRA applied based on the passed in configuration.
- Note: output projection lora is not supported because it is tied to token embeddings
-
- Args:
- lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
- LoRA should be applied to in each self-attention block. Options are
- ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
- apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
- Default: False
- vocab_size (int): number of tokens in vocabulary.
- num_layers (int): number of layers in the transformer decoder.
- num_heads (int): number of query heads. For MHA this is also the
- number of heads for key and value
- head_dim (int): dimension of head
- num_kv_heads (int): number of key and value heads.
- embed_dim (int): embedding dimension for self-attention
- intermediate_dim (int): intermediate dimension for MLP
- max_seq_len (int): maximum sequence length the model will be run with,
- attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
- Default: 0.0
- norm_eps (float): epsilon in RMS norms Default: 1e-6
- rope_base (int): base for the rotary positional embeddings. Default: 10_000
- lora_rank (int): rank of each low-rank approximation
- lora_alpha (float): scaling factor for the low-rank approximation
- lora_dropout (float): LoRA dropout probability. Default: 0.0
- use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
- introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
- quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base
- weights within linear layers LoRA is applied to. The final output linear projection is not
- supported for quantization currently.
-
- Returns:
- TransformerDecoder: Instantiation of Gemma model with LoRA applied to
- a subset of the attention projections in each layer.
- """
-
- tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
- output_proj = TiedLinear(tok_embeddings)
-
- layers = torch.nn.ModuleList()
-
- for layer_idx in range(num_layers):
- if apply_lora_to_mlp:
- mlp = lora_gemma_mlp(
- dim=embed_dim,
- hidden_dim=intermediate_dim,
- lora_rank=lora_rank,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- use_dora=use_dora,
- quantize_base=quantize_base,
- )
- else:
- mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
- self_att = lora_gemma2_self_attention(
- lora_modules=lora_attn_modules,
- embed_dim=embed_dim,
- num_heads=num_heads,
- num_kv_heads=num_kv_heads,
- head_dim=head_dim,
- rope_base=rope_base,
- max_seq_len=max_seq_len,
- attn_dropout=attn_dropout,
- # perform sliding window on half of the layers only
- sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None,
- softcapping=hidden_capping_value,
- query_pre_attn_scalar=query_pre_attn_scalar,
- lora_rank=lora_rank,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- use_dora = use_dora,
- quantize_base = quantize_base,
- )
-
- layer = TransformerSelfAttentionLayer(
- attn=self_att,
- mlp=mlp,
- sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
- mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
- sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps),
- mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps),
- )
- layers.append(layer)
-
- model = TransformerDecoder(
- tok_embeddings=tok_embeddings,
- layers=layers,
- max_seq_len=max_seq_len,
- num_heads=num_heads,
- output=output_proj,
- head_dim=head_dim,
- norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps)
- )
-
- if quantize_base:
- # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
- # so as to not increase peak memory
- # TODO this is clowny, figure out a better way to get what precision the rest
- # of the model is in
- _register_reparametrize_state_dict_hooks(model, dtype=tok_embeddings.weight.dtype)
-
- return model
-
-
-def lora_gemma2_self_attention(
- lora_modules: List[LORA_ATTN_MODULES],
- *,
- # MultiHeadAttention args
- embed_dim: int,
- num_heads: int,
- head_dim: int,
- num_kv_heads: int,
- max_seq_len: int,
- attn_dropout: float = 0.0,
- rope_base: int = 10_000,
- sliding_window_size: Optional[int] = None,
- softcapping: Optional[float] = 50.,
- query_pre_attn_scalar: Optional[int],
- # LoRA args
- lora_rank: int,
- lora_alpha: float,
- lora_dropout: float = 0.0,
- use_dora: bool = False,
- quantize_base: bool = False,
-
-) -> Gemma2Attention:
- if not lora_modules:
- raise ValueError(
- f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
- )
-
- num_kv_heads = num_kv_heads if num_kv_heads else num_heads
- adapter_cls = DoRALinear if use_dora else LoRALinear
-
- q_proj = (
- adapter_cls(
- embed_dim,
- num_heads * head_dim,
- rank=lora_rank,
- alpha=lora_alpha,
- dropout=lora_dropout,
- quantize_base=quantize_base,
- )
- if "q_proj" in lora_modules
- else (
- nn.Linear(embed_dim, num_heads * head_dim, bias=False)
- if not quantize_base
- else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
- )
- )
- k_proj = (
- adapter_cls(
- embed_dim,
- num_kv_heads * head_dim,
- rank=lora_rank,
- alpha=lora_alpha,
- dropout=lora_dropout,
- quantize_base=quantize_base,
- )
- if "k_proj" in lora_modules
- else (
- nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
- if not quantize_base
- else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
- )
- )
- v_proj = (
- adapter_cls(
- embed_dim,
- num_kv_heads * head_dim,
- rank=lora_rank,
- alpha=lora_alpha,
- dropout=lora_dropout,
- quantize_base=quantize_base,
- )
- if "v_proj" in lora_modules
- else (
- nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
- if not quantize_base
- else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
- )
- )
- output_proj = (
- adapter_cls(
- num_heads * head_dim,
- embed_dim,
- rank=lora_rank,
- alpha=lora_alpha,
- dropout=lora_dropout,
- quantize_base=quantize_base,
- )
- if "output_proj" in lora_modules
- else (
- nn.Linear(num_heads * head_dim, embed_dim, bias=False)
- if not quantize_base
- else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False)
- )
- )
-
- rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
-
- self_att = Gemma2Attention(
- embed_dim=embed_dim,
- num_heads=num_heads,
- num_kv_heads=num_kv_heads,
- head_dim=head_dim,
- q_proj=q_proj,
- k_proj=k_proj,
- v_proj=v_proj,
- output_proj=output_proj,
- pos_embeddings=rope,
- kv_cache=None,
- max_seq_len=max_seq_len,
- attn_dropout=attn_dropout,
- sliding_window_size=sliding_window_size,
- softcapping=softcapping,
- query_pre_attn_scalar=query_pre_attn_scalar
- )
- return self_att
\ No newline at end of file
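
For reference after this deletion: the removed builder's soft-capping is simply cap * tanh(x / cap), and the sliding window was enabled on every other layer. A small standalone sketch of both ideas (not the deleted module itself):

    import torch

    def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
        # Squashes logits into (-cap, cap) while staying roughly the identity
        # for |scores| << cap.
        return cap * torch.tanh(scores / cap)

    scores = torch.tensor([-100.0, -10.0, 0.0, 10.0, 100.0])
    print(soft_cap(scores, 50.0))  # values bounded in (-50, 50)

    # Sliding-window attention on even-indexed layers only, full attention on
    # the rest, mirroring the per-layer choice in the deleted gemma2 builder.
    sliding_window_size = 4096
    windows = [sliding_window_size if (i % 2) == 0 else None for i in range(6)]
    print(windows)  # [4096, None, 4096, None, 4096, None]
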
diff --git a/torchtune/models/gemma2/_convert_weights.py b/torchtune/models/gemma2/_convert_weights.py
deleted file mode 100644
index fa4df0e469..0000000000
--- a/torchtune/models/gemma2/_convert_weights.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Dict
-
-import torch
-
-from torchtune.models.convert_weights import get_mapped_key
-
-"""
-Gemma 2 and Gemma original implementations share different normalization but with
-the same name, so it is mandatory to differentiate their state dict in order to map
-correctly the different weights.
-They are essentially the same except for "model.layers.{}.post_attention_layernorm.weight" key.
-See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251
-"""
-
-_GEMMA2_FROM_HF = {
- "model.embed_tokens.weight": "tok_embeddings.weight",
- "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
- "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
- "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
- "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
- "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
- "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
- "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
- "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
- "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
- "model.layers.{}.post_attention_layernorm.weight": "layers.{}.sa_scale.scale",
- "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale",
- "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.mlp_scale.scale",
- "model.norm.weight": "norm.rms_norm.scale",
- "lm_head.weight": "output.weight",
-}
-
-
-def gemma2_hf_to_tune(
- state_dict: Dict[str, torch.Tensor],
- num_heads: int = 32,
- num_kv_heads: int = 32,
- dim: int = 4096,
- head_dim: int = None,
-) -> Dict[str, torch.Tensor]:
- """
- Convert a state dict from HF's format to torchtune's format. State dicts
- from multiple checkpoint files should be consolidated into a single state dict
- before calling this function.
-
- Eg of HF-format state dict can be found in the ``meta-llama/Llama-2-7b-hf``
- repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf).
-
- Args:
- state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
- num_heads (int): Number of heads in the model.
- num_kv_heads (int): Number of heads in the key/value projection layers.
- dim (int): Dimension of the model.
- head_dim (int): Dimension of the head. If not provided, it will be calculated
- as dim // num_heads.
-
- Returns:
- Dict[str, torch.Tensor]: State dict in torchtune's format.
- """
- converted_state_dict = {}
- if head_dim is None:
- head_dim = dim // num_heads
-
- def _permute(t, n_heads):
- return (
- t.view(n_heads, 2, head_dim // 2, dim)
- .transpose(1, 2)
- .reshape((head_dim * n_heads), dim)
- )
-
- for key, value in state_dict.items():
- if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings
- new_key = get_mapped_key(key, _GEMMA2_FROM_HF)
- if "q_proj" in key:
- value = _permute(value, num_heads)
- elif "k_proj" in key:
- value = _permute(value, num_kv_heads)
-
- converted_state_dict[new_key] = value
- return converted_state_dict
-
-
-def gemma2_tune_to_hf(
- state_dict: Dict[str, torch.Tensor],
- num_heads: int = 32,
- num_kv_heads: int = 32,
- dim: int = 4096,
- head_dim: int = None,
-):
- """
- Convert a state dict from torchtune's format to HF's format. This function
- doesn't handle any sharding or splitting of state dicts. It follows the
- state_dict IN -> state_dict OUT pattern.
-
- Args:
- state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
- num_heads (int): Number of heads in the model.
- num_kv_heads (int): Number of heads in the key/value projection layers.
- dim (int): Dimension of the model.
- head_dim (int): Dimension of model attention heads. Default None.
-
- Returns:
- Dict[str, torch.Tensor]: State dict in HF's format.
- """
- converted_state_dict = {}
- inverted_mapping_dict = {v: k for k, v in _GEMMA2_FROM_HF.items()}
-
- if head_dim is None:
- head_dim = dim // num_heads
-
- def _permute(t, n_heads):
- return (
- t.view(n_heads, head_dim // 2, 2, dim)
- .transpose(1, 2)
- .reshape((head_dim * n_heads), dim)
- )
-
- for key, value in state_dict.items():
- new_key = get_mapped_key(key, inverted_mapping_dict)
- if "q_proj" in key:
- value = _permute(value, num_heads)
- elif "k_proj" in key:
- value = _permute(value, num_kv_heads)
- converted_state_dict[new_key] = value
-
- return converted_state_dict
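
The two _permute helpers deleted above convert q/k projection weights between the HF rotary-embedding layout and torchtune's. A quick toy-size check that they are inverses of each other (shapes are illustrative, not Gemma2's real dimensions):

    import torch

    n_heads, head_dim, dim = 4, 8, 32

    def hf_to_tune_permute(t: torch.Tensor) -> torch.Tensor:
        return (
            t.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape(head_dim * n_heads, dim)
        )

    def tune_to_hf_permute(t: torch.Tensor) -> torch.Tensor:
        return (
            t.view(n_heads, head_dim // 2, 2, dim)
            .transpose(1, 2)
            .reshape(head_dim * n_heads, dim)
        )

    w = torch.randn(n_heads * head_dim, dim)
    # Round-tripping a weight through both conversions recovers it exactly.
    assert torch.equal(tune_to_hf_permute(hf_to_tune_permute(w)), w)
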
diff --git a/torchtune/models/gemma2/_model_builders.py b/torchtune/models/gemma2/_model_builders.py
deleted file mode 100644
index a07021c518..0000000000
--- a/torchtune/models/gemma2/_model_builders.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-from typing import List
-
-from torchtune.models.gemma2._component_builders import gemma2, lora_gemma2
-from torchtune.modules import TransformerDecoder
-
-from torchtune.modules.peft import LORA_ATTN_MODULES
-from functools import partial
-
-"""
-Model builders build specific instantiations using component builders. For example
-the ``gemma_2b`` model builder uses the ``gemma2`` component builder.
-"""
-
-
-def gemma2_2b() -> TransformerDecoder:
- """
- Builder for creating a Gemma2 2B model initialized w/ the default 2b parameter values
- from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py
-
- Returns:
- TransformerDecoder: Instantiation of Gemma2 2B model
- """
- return gemma2(
- vocab_size=256_000,
- num_layers=26,
- num_heads=8,
- head_dim=256,
- num_kv_heads=4,
- embed_dim=2304,
- intermediate_dim=9216,
- max_seq_len=8192,
- attn_dropout=0.0,
- norm_eps=1e-6,
- hidden_capping_value=30.0,
- final_capping_value=50.0,
- sliding_window_size=4096,
- )
-
-
-def lora_gemma2_2b(
- lora_attn_modules: List[LORA_ATTN_MODULES],
- apply_lora_to_mlp: bool = False,
- lora_rank: int = 8,
- lora_alpha: float = 16,
- lora_dropout: float = 0.0,
- use_dora: bool = False,
- quantize_base: bool = False,
-) -> TransformerDecoder:
- """
- Builder for creating a Gemma2 2B model with LoRA enabled.
-
- The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_2b`,
- while LoRA default params are based on
- https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.
-
- Args:
- lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
- LoRA should be applied to in each self-attention block. Options are
- ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
- apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
- Default: False
- lora_rank (int): rank of each low-rank approximation
- lora_alpha (float): scaling factor for the low-rank approximation
- lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
- use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
- introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
- quantize_base (bool): Whether to quantize base model weights
-
- Returns:
- TransformerDecoder: Instantiation of Gemma2 2B model with LoRA applied
- """
- return lora_gemma2(
- lora_attn_modules=lora_attn_modules,
- apply_lora_to_mlp=apply_lora_to_mlp,
- vocab_size=256_000,
- num_layers=26,
- num_heads=8,
- head_dim=256,
- num_kv_heads=4,
- embed_dim=2304,
- intermediate_dim=9216,
- max_seq_len=8192,
- attn_dropout=0.0,
- norm_eps=1e-6,
- hidden_capping_value=30.0,
- final_capping_value=50.0,
- sliding_window_size=4096,
- lora_rank=lora_rank,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- use_dora=use_dora,
- quantize_base=quantize_base,
- )
-
-qlora_gemma2_2b = partial(lora_gemma2_2b, quantize_base=True)
-
-qlora_gemma2_2b.__doc__ = """
-Builder for creating a Gemma2 model with QLoRA enabled. Base model weights in linear layers
-that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
-Please see `lora_gemm2a_2b` for full API arguments.
-"""
-
-
-
-def gemma2_9b() -> TransformerDecoder:
- """
- Builder for creating a Gemma2 9B model initialized w/ the default 9b parameter values
- from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py
-
- Returns:
- TransformerDecoder: Instantiation of Gemma 9B model
- """
- return gemma2(
- vocab_size=256_000,
- num_layers=42,
- num_heads=16,
- head_dim=256,
- num_kv_heads=8,
- embed_dim=3584,
- intermediate_dim=14336,
- max_seq_len=8192,
- attn_dropout=0.0,
- norm_eps=1e-6,
- hidden_capping_value=30.0,
- final_capping_value=50.0,
- sliding_window_size=4096,
- )
-
-
-def lora_gemma2_9b(
- lora_attn_modules: List[LORA_ATTN_MODULES],
- apply_lora_to_mlp: bool = False,
- lora_rank: int = 8,
- lora_alpha: float = 16,
- lora_dropout: float = 0.0,
- use_dora: bool = False,
- quantize_base: bool = False,
-) -> TransformerDecoder:
- """
- Builder for creating a Gemma 9B model with LoRA enabled.
-
- The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_7b`,
- while LoRA default params are based on
- https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.
-
- Args:
- lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
- LoRA should be applied to in each self-attention block. Options are
- ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
- apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
- Default: False
- lora_rank (int): rank of each low-rank approximation
- lora_alpha (float): scaling factor for the low-rank approximation
- lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
- use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
- introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
- quantize_base (bool): Whether to quantize base model weights
-
- Returns:
- TransformerDecoder: Instantiation of Gemma2 9B model with LoRA applied
- """
- return lora_gemma2(
- lora_attn_modules=lora_attn_modules,
- apply_lora_to_mlp=apply_lora_to_mlp,
- vocab_size=256_000,
- num_layers=42,
- num_heads=16,
- head_dim=256,
- num_kv_heads=8,
- embed_dim=3584,
- intermediate_dim=14336,
- max_seq_len=8192,
- attn_dropout=0.0,
- norm_eps=1e-6,
- hidden_capping_value=30.0,
- final_capping_value=50.0,
- sliding_window_size=4096,
- lora_rank=lora_rank,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- use_dora=use_dora,
- quantize_base=quantize_base,
- )
-
-qlora_gemma2_9b = partial(lora_gemma2_9b, quantize_base=True)
-
-qlora_gemma2_9b.__doc__ = """
-Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers
-that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
-Please see `lora_gemma2_9b` for full API arguments.
-"""
-
-def gemma2_27b() -> TransformerDecoder:
- """
- Builder for creating a Gemma2 27B model initialized w/ the default 27b parameter values
- from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py
-
- Returns:
- TransformerDecoder: Instantiation of Gemma2 27B model
- """
- return gemma2(
- vocab_size=256_000,
- num_layers=46,
- num_heads=32,
- head_dim=128,
- num_kv_heads=16,
- embed_dim=4608,
- intermediate_dim=36864,
- max_seq_len=8192,
- attn_dropout=0.0,
- norm_eps=1e-6,
- hidden_capping_value=30.0,
- final_capping_value=50.0,
- sliding_window_size=4096,
- query_pre_attn_scalar=144,
- )
-
-
-def lora_gemma2_27b(
- lora_attn_modules: List[LORA_ATTN_MODULES],
- apply_lora_to_mlp: bool = False,
- lora_rank: int = 8,
- lora_alpha: float = 16,
- lora_dropout: float = 0.0,
- use_dora: bool = False,
- quantize_base: bool = False,
-) -> TransformerDecoder:
- """
- Builder for creating a Gemma2 27B model with LoRA enabled.
-
- The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_7b`,
- while LoRA default params are based on
- https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.
-
- Args:
- lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
- LoRA should be applied to in each self-attention block. Options are
- ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
- apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
- Default: False
- lora_rank (int): rank of each low-rank approximation
- lora_alpha (float): scaling factor for the low-rank approximation
- lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
- use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
- introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
- quantize_base (bool): Whether to quantize base model weights
-
- Returns:
- TransformerDecoder: Instantiation of Gemma2 27B model with LoRA applied
- """
- return lora_gemma2(
- lora_attn_modules=lora_attn_modules,
- apply_lora_to_mlp=apply_lora_to_mlp,
- vocab_size=256_000,
- num_layers=46,
- num_heads=32,
- head_dim=128,
- num_kv_heads=16,
- embed_dim=4608,
- intermediate_dim=36864,
- max_seq_len=8192,
- attn_dropout=0.0,
- norm_eps=1e-6,
- hidden_capping_value=30.0,
- final_capping_value=50.0,
- sliding_window_size=4096,
- query_pre_attn_scalar=144,
- lora_rank=lora_rank,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- use_dora=use_dora,
- quantize_base=quantize_base,
- )
-
-qlora_gemma2_27b = partial(lora_gemma2_27b, quantize_base=True)
-
-qlora_gemma2_27b.__doc__ = """
-Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers
-that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
-Please see `lora_gemma2_27b` for full API arguments.
-"""
diff --git a/torchtune/models/llama3_2_vision/_component_builders.py b/torchtune/models/llama3_2_vision/_component_builders.py
index 8881c87531..02f05eae35 100644
--- a/torchtune/models/llama3_2_vision/_component_builders.py
+++ b/torchtune/models/llama3_2_vision/_component_builders.py
@@ -170,6 +170,7 @@ def llama3_2_vision_decoder(
by :func:`~torchtune.modules.KVCache`.
encoder_max_seq_len (int): maximum sequence length the encoder will be run with, as used
by :func:`~torchtune.modules.KVCache`.
+ rope_base (int): base for the rotary positional embeddings. Default: 500_000
intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`.
diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py
index ec79f0a4ba..8255fdac8c 100644
--- a/torchtune/training/checkpointing/_checkpointer.py
+++ b/torchtune/training/checkpointing/_checkpointer.py
@@ -488,16 +488,6 @@ def load_checkpoint(self) -> Dict[str, Any]:
"supported_aspect_ratios", None
),
)
- elif self._model_type == ModelType.GEMMA2:
- from torchtune.models.gemma2._convert_weights import gemma2_hf_to_tune
-
- converted_state_dict[training.MODEL_KEY] = gemma2_hf_to_tune(
- merged_state_dict,
- num_heads=self._config["num_attention_heads"],
- num_kv_heads=self._config["num_key_value_heads"],
- dim=self._config["hidden_size"],
- head_dim=self._config.get("head_dim", None),
- )
else:
converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune(
merged_state_dict,
@@ -588,16 +578,6 @@ def save_checkpoint(
"supported_aspect_ratios", None
),
)
- elif self._model_type == ModelType.GEMMA2:
- from torchtune.models.gemma2._convert_weights import gemma2_tune_to_hf
-
- state_dict[training.MODEL_KEY] = gemma2_tune_to_hf(
- state_dict[training.MODEL_KEY],
- num_heads=self._config["num_attention_heads"],
- num_kv_heads=self._config["num_key_value_heads"],
- dim=self._config["hidden_size"],
- head_dim=self._config.get("head_dim", None),
- )
else:
state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf(
state_dict[training.MODEL_KEY],
diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py
index 2fa7265194..2d353b007c 100644
--- a/torchtune/training/checkpointing/_utils.py
+++ b/torchtune/training/checkpointing/_utils.py
@@ -45,7 +45,6 @@ class ModelType(Enum):
Attributes:
GEMMA (str): Gemma family of models. See :func:`~torchtune.models.gemma.gemma`
- GEMMA2 (str): Gemma 2 family of models. See :func:`~torchtune.models.gemma2.gemma2`
LLAMA2 (str): Llama2 family of models. See :func:`~torchtune.models.llama2.llama2`
LLAMA3 (str): Llama3 family of models. See :func:`~torchtune.models.llama3.llama3`
LLAMA3_2 (str): Llama3.2 family of models. See :func:`~torchtune.models.llama3_2.llama3_2`
@@ -66,7 +65,6 @@ class ModelType(Enum):
"""
GEMMA: str = "gemma"
- GEMMA2: str = "gemma2"
LLAMA2: str = "llama2"
LLAMA3: str = "llama3"
LLAMA3_2: str = "llama3_2"