Support Early Exit Loss and/or Layer Dropout #1076

Merged · 90 commits · Dec 6, 2024
Commits
97cb9a8
start of layer dropout implementation
mostafaelhoushi Jun 9, 2024
4a25c5b
have different dropouts at different layers
mostafaelhoushi Jun 9, 2024
ac8ad0b
add option to specify which layers to apply dropout
mostafaelhoushi Jun 9, 2024
ae61c85
start early exit loss
mostafaelhoushi Jun 10, 2024
735d2a8
parallelize processing of early exit losses
mostafaelhoushi Jun 19, 2024
be912a6
use absolute imports
mostafaelhoushi Jun 19, 2024
0686dd2
remove unnecessary sync
mostafaelhoushi Jun 19, 2024
4e4783f
move early exit loss to separate file and add layers as arg
mostafaelhoushi Jun 19, 2024
268813e
perform loss scaling every iteration
mostafaelhoushi Jun 19, 2024
ccb4a50
return hidden states as an output rather than storing
mostafaelhoushi Jun 19, 2024
ff7d157
ensure last layer is always included
mostafaelhoushi Jun 20, 2024
5a23811
return either last logits or hidden states
mostafaelhoushi Jun 20, 2024
e11aeba
fix scaling layers
mostafaelhoushi Jun 20, 2024
f9e164f
rotational early exit curriculum
mostafaelhoushi Jun 20, 2024
069b661
set early exit params from cli
mostafaelhoushi Jul 22, 2024
954d097
ensure last layer loss is always calculated
mostafaelhoushi Jul 22, 2024
5789745
implement gradual early exit
mostafaelhoushi Jul 22, 2024
c3534e6
get streaming to work
mostafaelhoushi Jul 22, 2024
d1c6963
Merge branch 'main' into layerskip
mostafaelhoushi Nov 12, 2024
7849130
add separate recipe for early exit
mostafaelhoushi Nov 13, 2024
df89c4f
port early exit loss code from PR
mostafaelhoushi Nov 13, 2024
6cedb19
convert boolean array to indices
mostafaelhoushi Nov 17, 2024
a83da5a
decide on hidden outputs by member variable not forward pass
mostafaelhoushi Nov 17, 2024
2a8791d
add early exit recipe config
mostafaelhoushi Nov 17, 2024
a326937
refactor unembedding
mostafaelhoushi Nov 17, 2024
8ba6ab4
got early exit loss to work
mostafaelhoushi Nov 18, 2024
681e7ca
add TopV2 instruction set
mostafaelhoushi Nov 19, 2024
119ac7d
ensure all early exit loss params from cfg file are passed to code
mostafaelhoushi Nov 19, 2024
3ec9d23
fix gradual early exit
mostafaelhoushi Nov 19, 2024
04a590f
add test cases for early exit loss
mostafaelhoushi Nov 19, 2024
9b5c96a
add more assertions for rotational early exit
mostafaelhoushi Nov 19, 2024
3319ab0
test to follow training code
mostafaelhoushi Nov 19, 2024
619b3eb
fix curriculum update
mostafaelhoushi Nov 20, 2024
d376ddd
update recipe
mostafaelhoushi Nov 21, 2024
ff3977b
reset changes to data loading
mostafaelhoushi Nov 21, 2024
75b2e01
code cleanup
mostafaelhoushi Nov 23, 2024
33a95f5
rename early_exit to early_exit_loss
mostafaelhoushi Nov 23, 2024
5d7e903
address some early exit TODOs
mostafaelhoushi Nov 23, 2024
87f2ee0
get layer dropout to work
mostafaelhoushi Nov 23, 2024
1de0c2a
clean up early exit curriculum
mostafaelhoushi Nov 24, 2024
2b0cdd1
enable grad curriculum for subset of layers + clear hidden_states at …
mostafaelhoushi Nov 24, 2024
7973459
add docstring for slice_str_to_array
mostafaelhoushi Nov 24, 2024
baed8a9
support commas and add assertion statements
mostafaelhoushi Nov 24, 2024
27f6b56
add test cases for slice_to_str_array
mostafaelhoushi Nov 24, 2024
63e7c5b
add copyright header
mostafaelhoushi Nov 24, 2024
638056b
support single index
mostafaelhoushi Nov 24, 2024
a20b07c
add new line at end of file
mostafaelhoushi Nov 24, 2024
64210e6
Merge branch 'main' into layerskip
mostafaelhoushi Nov 24, 2024
98897a8
add layer dropout test cases
mostafaelhoushi Nov 24, 2024
2cc94cc
rename apply_layer_dropout to prepare_layer_dropout
mostafaelhoushi Nov 24, 2024
f4f8e02
add test cases for get_scale
mostafaelhoushi Nov 24, 2024
fed955e
cleanup get_scale + re-write mathematically equivalent + ensure max s…
mostafaelhoushi Nov 24, 2024
ca7d8da
test layer_dropout
mostafaelhoushi Nov 24, 2024
0146764
start adding early exit loss and layer dropout to docstring
mostafaelhoushi Nov 24, 2024
f599eca
fix and update code and test cases to handle updating last layer sepa…
mostafaelhoushi Nov 24, 2024
2437092
change match to if-else for CI
mostafaelhoushi Nov 24, 2024
ad090af
add assertion on type of loss fn for early exit loss
mostafaelhoushi Nov 25, 2024
cec8cd4
add docstring and slightly change attribute of layer_dropout and earl…
mostafaelhoushi Nov 25, 2024
b69f2f3
refactor layer_dropout and add test cases on wrapper
mostafaelhoushi Nov 25, 2024
a21cbd3
add TODO comment
mostafaelhoushi Nov 25, 2024
eb37cb6
fix error in checking if early exit loss is enabled
mostafaelhoushi Nov 25, 2024
2e3f502
change recipe defaults of dataset and layer_drop probability
mostafaelhoushi Nov 26, 2024
66a41b2
add detailed docstring to training script
mostafaelhoushi Nov 26, 2024
345a0a3
ensure we set last layer early exit enable correctly
mostafaelhoushi Nov 26, 2024
20c618c
ensure uniform early exit loss works
mostafaelhoushi Nov 26, 2024
f0e8d7f
add documentation to .yaml file and update doc in .py
mostafaelhoushi Nov 26, 2024
b03cb57
remove commented lines
mostafaelhoushi Nov 27, 2024
199b8dd
remove check on PyTorch version since we assume latest stable PyTorch
mostafaelhoushi Nov 27, 2024
6a2d79b
load curriculum step when resuming
mostafaelhoushi Nov 27, 2024
e5534ea
repeat arguments in derived classes
mostafaelhoushi Nov 27, 2024
d270d1f
rename percent_scale to fraction_scale and change its implementation
mostafaelhoushi Nov 27, 2024
e51419c
fixes to docstrings and config examples
mostafaelhoushi Dec 1, 2024
40b7987
check if cfg_early_exit_loss has curriculum
mostafaelhoushi Dec 1, 2024
0c18595
add comment to explain when has no effect
mostafaelhoushi Dec 1, 2024
3e68696
organize early exit loss tests into classes
mostafaelhoushi Dec 1, 2024
418951b
fix typo
mostafaelhoushi Dec 1, 2024
e5a53f9
test all loss scale types
mostafaelhoushi Dec 1, 2024
3567a24
use variable number of subset layers
mostafaelhoushi Dec 1, 2024
ae2108d
ensure get_scale returns values between 0 and 1
mostafaelhoushi Dec 1, 2024
71707de
add test cases for sigmoid
mostafaelhoushi Dec 2, 2024
78aff5a
make prepare_layer_dropout apply on a list of layers rather than a model
mostafaelhoushi Dec 2, 2024
0fb373b
Only add `optional` in docstring when argument is optional
mostafaelhoushi Dec 4, 2024
b66e23b
add Dropout class and prepare_layer_dropout APIs to docs
mostafaelhoushi Dec 4, 2024
cd8be64
add empty line between function description and Args
mostafaelhoushi Dec 4, 2024
2675b4c
remove assert statement as we added the check in testing
mostafaelhoushi Dec 4, 2024
00d8efa
change loss scale from enum to function
mostafaelhoushi Dec 5, 2024
78b8996
change curriculum from enum to function
mostafaelhoushi Dec 5, 2024
ed33ba9
rename scale_type to scale_fn
mostafaelhoushi Dec 6, 2024
c7f02de
change default
mostafaelhoushi Dec 6, 2024
69f840c
update docstring
mostafaelhoushi Dec 6, 2024
2 changes: 2 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -23,6 +23,8 @@ Modeling Components and Building Blocks
TransformerCrossAttentionLayer
TransformerDecoder
VisionTransformer
LayerDropout
prepare_layer_dropout

Losses
------
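`LayerDropout` and `prepare_layer_dropout` (the two entries added above) are the building blocks behind the `layer_dropout` options in the new recipe config. Below is a minimal sketch of the technique only, using illustrative names and an assumed exponential schedule rather than the actual torchtune signatures: each transformer layer is wrapped so that, during training, it is skipped with some probability, and that probability grows with depth when an "exp" scale is used.

```python
import math

import torch
import torch.nn as nn


class LayerDropoutSketch(nn.Module):
    """Illustrative wrapper (not the torchtune class): with probability ``prob``,
    skip the wrapped layer during training and pass the input through unchanged."""

    def __init__(self, layer: nn.Module, prob: float, disable_on_eval: bool = True):
        super().__init__()
        self.layer = layer
        self.prob = prob
        self.disable_on_eval = disable_on_eval

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        active = self.training or not self.disable_on_eval
        if active and torch.rand(()).item() < self.prob:
            return x  # layer is dropped for this forward pass
        return self.layer(x, *args, **kwargs)


def prepare_layer_dropout_sketch(
    layers: nn.ModuleList, prob: float = 0.2, layers_scale: str = "exp"
) -> None:
    """Wrap each layer in place with a depth-dependent drop probability.
    With ``layers_scale="exp"`` the probability ramps from 0 at the first layer
    up to ``prob`` at the last one; "uniform" applies ``prob`` everywhere.
    (Illustrative schedule; the PR's get_scale may differ.)"""
    n = len(layers)
    for i in range(n):
        if layers_scale == "exp":
            scale = (math.exp(i / max(n - 1, 1)) - 1) / (math.e - 1)
        else:
            scale = 1.0
        layers[i] = LayerDropoutSketch(layers[i], prob * scale)
```

At evaluation time the wrapper runs the layer unconditionally when `disable_on_eval` is set, mirroring the config flag of the same name in the recipe below.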
137 changes: 137 additions & 0 deletions recipes/dev/7B_full_early_exit.yaml
@@ -0,0 +1,137 @@
# Config for multi-device full finetuning with early exit loss and/or layer dropout
# in dev/early_exit_finetune_distributed.py using a Llama2 7B model on a small TOPv2
# instruction set.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# To reproduce experiments of various papers that use early exit loss and/or layer dropout:
# - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2:
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp
#
# - LITE (https://arxiv.org/abs/2310.18581):
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5
#
# - LayerDrop (https://arxiv.org/abs/1909.11556):
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2
#
# - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) (the paper also uses a curriculum for the layer drop probability, which is not yet implemented here):
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.5 layer_dropout.scale=exp
#
# This config works best for distributed training, i.e. when the model is being fine-tuned on 2+ GPUs.
#


# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-7b-hf/tokenizer.model
max_seq_len: null

# Dataset
dataset:
_component_: torchtune.datasets.instruct_dataset
source: WillHeld/top_v2
split: train
column_map:
input: utterance
output: semantic_parse

seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 8
epochs: 1
optimizer:
_component_: torch.optim.AdamW
fused: True
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/topv2-llama2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: False
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 5
warmup_steps: 3
active_steps: 2
num_cycles: 1

# Early Exit Loss
early_exit_loss:
layers: "0::4"
curriculum: torchtune.modules.early_exit_loss.RotationalEarlyExitCurriculum
scale_fn: torchtune.modules.early_exit_loss.sum_l_loss_scale
scale: 1.0

# Layer Dropout
layer_dropout:
prob: 0.2
layers: ":"
layers_scale: "exp"
disable_on_eval: True
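A note on the last two blocks: the `layers` fields accept Python-slice-style strings ("0::4" selects every fourth layer starting at 0, ":" selects all layers), and the early exit loss unembeds each selected layer's hidden state with the shared LM head, weights the per-layer losses, and scales the total by `scale`. The sketch below illustrates that flow; the helper names are hypothetical (not the torchtune API), a simple "later layers count more" weighting stands in for the configurable `scale_fn`, and the rotational/gradual curriculum that decides which selected layers are active at a given step is omitted.

```python
import torch
import torch.nn.functional as F


def layers_from_slice_str(spec: str, num_layers: int) -> list[int]:
    """Hypothetical helper mirroring the config's slice-style ``layers`` strings.
    "0::4" -> [0, 4, 8, ...], ":" -> all layers, "1::2" -> odd-indexed layers.
    (Single indices and comma-separated lists, which the recipe also accepts,
    are not handled in this sketch.)"""
    parts = [int(p) if p else None for p in spec.split(":")]
    return list(range(num_layers)[slice(*parts)])


def early_exit_loss_sketch(
    hidden_states: dict[int, torch.Tensor],  # layer index -> (batch, seq, dim)
    unembed,                                 # shared final norm + output projection
    labels: torch.Tensor,                    # (batch, seq)
    scale: float = 1.0,
) -> torch.Tensor:
    """Combine per-layer losses from the selected layers' hidden states.
    A linear weighting stands in for the configured scale_fn."""
    losses, weights = [], []
    for layer_idx, h in sorted(hidden_states.items()):
        logits = unembed(h)  # (batch, seq, vocab): every exit reuses the same LM head
        losses.append(
            F.cross_entropy(logits.flatten(0, 1), labels.flatten(), ignore_index=-100)
        )
        weights.append(float(layer_idx + 1))
    weights = torch.tensor(weights, device=losses[0].device)
    weights = weights / weights.sum()
    return scale * torch.sum(weights * torch.stack(losses))
```

With `layers: "0::4"` on a 32-layer Llama2 7B this attaches auxiliary losses to layers 0, 4, ..., 28, and, per the commits above, the final layer's loss is always included as well.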