Support Early Exit Loss and/or Layer Dropout #1076

Merged: 90 commits, Dec 6, 2024
Changes shown in the diff below are from 4 of the 90 commits (an early snapshot of this PR).
Commits (90)
97cb9a8
start of layer dropout implementation
mostafaelhoushi Jun 9, 2024
4a25c5b
have different dropouts at different layers
mostafaelhoushi Jun 9, 2024
ac8ad0b
add option to specify which layers to apply dropout
mostafaelhoushi Jun 9, 2024
ae61c85
start early exit loss
mostafaelhoushi Jun 10, 2024
735d2a8
parallelize processing of early exit losses
mostafaelhoushi Jun 19, 2024
be912a6
use absolute imports
mostafaelhoushi Jun 19, 2024
0686dd2
remove unnecessary sync
mostafaelhoushi Jun 19, 2024
4e4783f
move early exit loss to separate file and add layers as arg
mostafaelhoushi Jun 19, 2024
268813e
perform loss scaling every iteration
mostafaelhoushi Jun 19, 2024
ccb4a50
return hidden states as an output rather than storing
mostafaelhoushi Jun 19, 2024
ff7d157
ensure last layer is always included
mostafaelhoushi Jun 20, 2024
5a23811
return either last logits or hidden states
mostafaelhoushi Jun 20, 2024
e11aeba
fix scaling layers
mostafaelhoushi Jun 20, 2024
f9e164f
rotational early exit curriculum
mostafaelhoushi Jun 20, 2024
069b661
set early exit params from cli
mostafaelhoushi Jul 22, 2024
954d097
ensure last layer loss is always calculated
mostafaelhoushi Jul 22, 2024
5789745
implement gradual early exit
mostafaelhoushi Jul 22, 2024
c3534e6
get streaming to work
mostafaelhoushi Jul 22, 2024
d1c6963
Merge branch 'main' into layerskip
mostafaelhoushi Nov 12, 2024
7849130
add separate recipe for early exit
mostafaelhoushi Nov 13, 2024
df89c4f
port early exit loss code from PR
mostafaelhoushi Nov 13, 2024
6cedb19
convert boolean array to indices
mostafaelhoushi Nov 17, 2024
a83da5a
decide on hidden outputs by member variable not forward pass
mostafaelhoushi Nov 17, 2024
2a8791d
add early exit recipe config
mostafaelhoushi Nov 17, 2024
a326937
refactor unembedding
mostafaelhoushi Nov 17, 2024
8ba6ab4
got early exit loss to work
mostafaelhoushi Nov 18, 2024
681e7ca
add TopV2 instruction set
mostafaelhoushi Nov 19, 2024
119ac7d
ensure all early exit loss params from cfg file are passed to code
mostafaelhoushi Nov 19, 2024
3ec9d23
fix gradual early exit
mostafaelhoushi Nov 19, 2024
04a590f
add test cases for early exit loss
mostafaelhoushi Nov 19, 2024
9b5c96a
add more assertions for rotational early exit
mostafaelhoushi Nov 19, 2024
3319ab0
test to follow training code
mostafaelhoushi Nov 19, 2024
619b3eb
fix curriculum update
mostafaelhoushi Nov 20, 2024
d376ddd
update recipe
mostafaelhoushi Nov 21, 2024
ff3977b
reset changes to data loading
mostafaelhoushi Nov 21, 2024
75b2e01
code cleanup
mostafaelhoushi Nov 23, 2024
33a95f5
rename early_exit to early_exit_loss
mostafaelhoushi Nov 23, 2024
5d7e903
address some early exit TODOs
mostafaelhoushi Nov 23, 2024
87f2ee0
get layer dropout to work
mostafaelhoushi Nov 23, 2024
1de0c2a
clean up early exit curriculum
mostafaelhoushi Nov 24, 2024
2b0cdd1
enable grad curriculum for subset of layers + clear hidden_states at …
mostafaelhoushi Nov 24, 2024
7973459
add docstring for slice_str_to_array
mostafaelhoushi Nov 24, 2024
baed8a9
support commas and add assertion statements
mostafaelhoushi Nov 24, 2024
27f6b56
add test cases for slice_to_str_array
mostafaelhoushi Nov 24, 2024
63e7c5b
add copyright header
mostafaelhoushi Nov 24, 2024
638056b
support single index
mostafaelhoushi Nov 24, 2024
a20b07c
add new line at end of file
mostafaelhoushi Nov 24, 2024
64210e6
Merge branch 'main' into layerskip
mostafaelhoushi Nov 24, 2024
98897a8
add layer dropout test cases
mostafaelhoushi Nov 24, 2024
2cc94cc
rename apply_layer_dropout to prepare_layer_dropout
mostafaelhoushi Nov 24, 2024
f4f8e02
add test cases for get_scale
mostafaelhoushi Nov 24, 2024
fed955e
cleanup get_scale + re-write mathematically equivalent + ensure max s…
mostafaelhoushi Nov 24, 2024
ca7d8da
test layer_dropout
mostafaelhoushi Nov 24, 2024
0146764
start adding early exit loss and layer dropout to docstring
mostafaelhoushi Nov 24, 2024
f599eca
fix and update code and test cases to handle updating last layer sepa…
mostafaelhoushi Nov 24, 2024
2437092
change match to if-else for CI
mostafaelhoushi Nov 24, 2024
ad090af
add assertion on type of loss fn for early exit loss
mostafaelhoushi Nov 25, 2024
cec8cd4
add docstring and slightly change attribute of layer_dropout and earl…
mostafaelhoushi Nov 25, 2024
b69f2f3
refactor layer_dropout and add test cases on wrapper
mostafaelhoushi Nov 25, 2024
a21cbd3
add TODO comment
mostafaelhoushi Nov 25, 2024
eb37cb6
fix error in checking if early exit loss is enabled
mostafaelhoushi Nov 25, 2024
2e3f502
change recipe defaults of dataset and layer_drop probability
mostafaelhoushi Nov 26, 2024
66a41b2
add detailed docstring to training script
mostafaelhoushi Nov 26, 2024
345a0a3
ensure we set last layer early exit enable correctly
mostafaelhoushi Nov 26, 2024
20c618c
ensure uniform early exit loss works
mostafaelhoushi Nov 26, 2024
f0e8d7f
add documentation to .yaml file and update doc in .py
mostafaelhoushi Nov 26, 2024
b03cb57
remove commented lines
mostafaelhoushi Nov 27, 2024
199b8dd
remove check on PyTorch version since we assume latest stable PyTorch
mostafaelhoushi Nov 27, 2024
6a2d79b
load curriculum step when resuming
mostafaelhoushi Nov 27, 2024
e5534ea
repeat arguments in derived classes
mostafaelhoushi Nov 27, 2024
d270d1f
rename percent_scale to fraction_scale and change its implementation
mostafaelhoushi Nov 27, 2024
e51419c
fixes to docstrings and config examples
mostafaelhoushi Dec 1, 2024
40b7987
check if cfg_early_exit_loss has curriculum
mostafaelhoushi Dec 1, 2024
0c18595
add comment to explain when has no effect
mostafaelhoushi Dec 1, 2024
3e68696
organize early exit loss tests into classes
mostafaelhoushi Dec 1, 2024
418951b
fix typo
mostafaelhoushi Dec 1, 2024
e5a53f9
test all loss scale types
mostafaelhoushi Dec 1, 2024
3567a24
use variable number of subset layers
mostafaelhoushi Dec 1, 2024
ae2108d
ensure get_scale returns values between 0 and 1
mostafaelhoushi Dec 1, 2024
71707de
add test cases for sigmoid
mostafaelhoushi Dec 2, 2024
78aff5a
make prepare_layer_dropout apply on a list of layers rather than a model
mostafaelhoushi Dec 2, 2024
0fb373b
Only add `optional` in docstring when argument is optional
mostafaelhoushi Dec 4, 2024
b66e23b
add Dropout class and prepare_layer_dropout APIs to docs
mostafaelhoushi Dec 4, 2024
cd8be64
add empty line between function description and Args
mostafaelhoushi Dec 4, 2024
2675b4c
remove assert statement as we added the check in testing
mostafaelhoushi Dec 4, 2024
00d8efa
change loss scale from enum to function
mostafaelhoushi Dec 5, 2024
78b8996
change curriculum from enum to function
mostafaelhoushi Dec 5, 2024
ed33ba9
rename scale_type to scale_fn
mostafaelhoushi Dec 6, 2024
c7f02de
change default
mostafaelhoushi Dec 6, 2024
69f840c
update docstring
mostafaelhoushi Dec 6, 2024
16 changes: 15 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -516,14 +516,28 @@ def train(self) -> None:
input_pos.to(self._device) if input_pos is not None else None
)

logits = self._model(tokens, mask=mask, input_pos=input_pos)
logits = self._model(tokens, mask=mask, input_pos=input_pos, output_hidden_states=True)
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)

# Compute early exit loss
if self._model.output_hidden_states:
# TODO: calculate early_logits in one shot:
# logits_early = self._model.output(self._model.norm(torch.stack(tuple(self._model.output_hidden_states.values()))))
for layer_id, hidden_state in self._model.output_hidden_states.items():
h_early = self._model.norm(hidden_state)
logits_early = self._model.output(h_early)
# Shift so that tokens < n predict n
logits_early = logits_early[..., :-1, :].contiguous()
logits_early = logits_early.transpose(1, 2)
# Compute early loss
loss_early = self._loss_fn(logits_early, labels)
loss += 0.1 / len(self._model.layers) * loss_early

loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
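
The loop above unembeds each requested hidden state separately. A minimal sketch (not part of this diff) of the batched alternative the TODO hints at, assuming every cached hidden state shares the same [b, s, d] shape, that self._model.norm and self._model.output broadcast over an extra leading dimension, and with labels already shifted as above:

hidden = torch.stack(tuple(self._model.output_hidden_states.values()))   # [e, b, s, d]
logits_early = self._model.output(self._model.norm(hidden))              # [e, b, s, v]
logits_early = logits_early[..., :-1, :].contiguous().transpose(-1, -2)  # [e, b, v, s-1]
loss_early = sum(self._loss_fn(le, labels) for le in logits_early)       # same contract as the call above
loss = loss + 0.1 / len(self._model.layers) * loss_early
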
3 changes: 3 additions & 0 deletions torchtune/modules/__init__.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from .attention import CausalSelfAttention # noqa
from .layer_dropout import LayerDropout, create_layer_dropout_modules # noqa
from .common_utils import reparametrize_as_dtype_state_dict_post_hook
from .feed_forward import FeedForward # noqa
from .kv_cache import KVCache # noqa
@@ -24,4 +25,6 @@
"TransformerDecoderLayer",
"TransformerClassifier",
"reparametrize_as_dtype_state_dict_post_hook",
"LayerDropout",
"create_layer_dropout_modules",
]
27 changes: 27 additions & 0 deletions torchtune/modules/common_utils.py
@@ -48,3 +48,30 @@ def reparametrize_as_dtype_state_dict_post_hook(
state_dict[k] = v.to(dtype)
if offload_to_cpu:
state_dict[k] = state_dict[k].cpu()

def slice_str_to_array(slice_str, length):
    """Convert a Python-style slice string (e.g. "1:4", "::2", ":") into a boolean mask of size ``length``."""
    # Parse the slice string
parts = slice_str.split(':')
start, end, step = None, None, None

if len(parts) == 1 and parts[0] != '':
start = int(parts[0])
elif len(parts) == 2:
start = int(parts[0]) if parts[0] != '' else None
end = int(parts[1]) if parts[1] != '' else None
elif len(parts) == 3:
start = int(parts[0]) if parts[0] != '' else None
end = int(parts[1]) if parts[1] != '' else None
step = int(parts[2]) if parts[2] != '' else None

# Create a boolean array based on the slice
result = [False] * length
slice_indices = range(start if start is not None else 0,
end if end is not None else length,
step if step is not None else 1)

for i in slice_indices:
if 0 <= i < length:
result[i] = True

return result
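
A few illustrative calls (not part of the diff) showing the mask this helper produces in its current form; note that a lone index selects from that index to the end here (the later "support single index" commit appears to revisit this):

slice_str_to_array("1:4", 6)   # [False, True, True, True, False, False]
slice_str_to_array("::2", 5)   # [True, False, True, False, True]
slice_str_to_array("3", 5)     # [False, False, False, True, True]
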
85 changes: 85 additions & 0 deletions torchtune/modules/layer_dropout.py
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Callable, Optional
import math
import torch

from .common_utils import slice_str_to_array

class LayerDropout(torch.nn.Module):
def __init__(self, prob=0.0, dim=0, disable_on_eval=True, seed=None):
super().__init__()
self.prob: float = prob
self.dim = dim
self.disable_on_eval: bool = disable_on_eval
self.generator = torch.Generator(device="cpu")
self.inferred: float = None

if seed is not None:
self.generator.manual_seed(seed)

def forward(self, function: Callable, input: torch.Tensor, *args, **kwargs):
n = input.shape[self.dim]

if self.prob == 0 or (self.disable_on_eval and self.training is False):
self.inferred = 1.0
return function(input, *args, **kwargs)

skip = torch.bernoulli(torch.Tensor((n) * [self.prob]), generator=self.generator).to(input.device).to(input.dtype)
self.inferred = 1 - torch.mean(skip)
ind_selected = (skip == 0).nonzero().squeeze().to(input.device)

if ind_selected.numel() > 0:
x_selected = torch.index_select(input, self.dim, ind_selected)
out_selected = function(x_selected, *args, **kwargs)

out = input.clone()
assert self.dim == 0, "Currently only supporting dropping elements along the 0th dimension"
if ind_selected.numel() > 0:
out[ind_selected] = out_selected
return out
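
An illustrative call pattern (not part of the diff): the module wraps the call to a layer, skips a random subset of samples along dim, and passes the skipped samples through unchanged. Here torch.nn.Linear stands in for a TransformerDecoderLayer:

layer = torch.nn.Linear(16, 16)
x = torch.randn(64, 16)                        # 64 samples along dim 0

drop = LayerDropout(prob=0.0)
assert torch.equal(drop(layer, x), layer(x))   # prob=0 (or eval mode): layer always applied

drop = LayerDropout(prob=0.5, seed=0)          # in training, each sample skips the layer with prob 0.5
out = drop(layer, x)                           # skipped rows of out equal the matching rows of x
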

class ScaleType(str, Enum):
UNIFORM = "uniform"
EXP = "exp"
LINEAR = "linear"
LOG = "log"
SIN = "sin"
SIGMOID = "sigmoid"
STEP = "step"

def get_scale(scale_type: ScaleType, scale_period: int, val: int):
if scale_period == 0:
return 1

# all the equations below aim to make scale = 0 when val=0, and scale = 1 when val=scale_period
return {
ScaleType.UNIFORM: 1,
ScaleType.EXP: math.exp(val * math.log(2) / scale_period) - 1,
ScaleType.LINEAR: val / scale_period,
ScaleType.LOG: math.log(val + 1) / math.log(scale_period + 1),
ScaleType.SIN: math.sin(0.5 * math.pi * val / scale_period),
ScaleType.SIGMOID: 1 / (1 + math.exp(-10 * (val / scale_period - 0.5))),
}[scale_type]
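
Apart from UNIFORM, each entry above maps val=0 to 0 and val=scale_period to 1, except SIGMOID, which only reaches about 0.0067 and 0.9933 at the endpoints in this version; ScaleType.STEP is declared but has no entry here, so passing it raises a KeyError. A quick illustrative check (not part of the diff):

for st in (ScaleType.LINEAR, ScaleType.EXP, ScaleType.LOG, ScaleType.SIN, ScaleType.SIGMOID):
    print(st.value, get_scale(st, scale_period=8, val=0), get_scale(st, scale_period=8, val=8))
# linear/exp/log/sin go from 0.0 to 1.0; sigmoid goes from ~0.0067 to ~0.9933
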

def create_layer_dropout_modules(num_layers: int, prob_max: float= 0.0, prob_layer_scale: ScaleType = ScaleType.EXP, layers_str: Optional[str] = None, disable_on_eval: bool = True):
layer_dropouts = torch.nn.ModuleList()
has_dropout = slice_str_to_array(layers_str, num_layers) if layers_str else [True] * num_layers

for layer_id in range(num_layers):
prob = prob_max * get_scale(
scale_type = prob_layer_scale,
scale_period = num_layers - 1,
val = layer_id,
) if has_dropout[layer_id] else 0.0
assert prob >= 0.0 and prob <= prob_max, f"prob={prob} should be between 0 and {prob_max}"
# We would like each layer to have a different seed, so that we don't have the same samples skipped across layers. Hence, we use the layer_id as a seed for each layer's dropout.
layer_dropout = LayerDropout(prob, disable_on_eval=disable_on_eval, seed=layer_id)
layer_dropouts.append(layer_dropout)

return layer_dropouts
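
For instance (illustrative, not part of the diff), with the default exponential scale the per-layer probabilities ramp from 0 at the first layer up to prob_max at the last, while the per-layer seeds keep the skipped samples decorrelated across layers:

dropouts = create_layer_dropout_modules(num_layers=4, prob_max=0.2, layers_str=":")
print([round(d.prob, 3) for d in dropouts])    # [0.0, 0.052, 0.117, 0.2]
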
22 changes: 19 additions & 3 deletions torchtune/modules/transformer.py
@@ -4,12 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Optional
from collections import OrderedDict
from typing import List, Optional, Union

import torch
from torch import nn, Tensor

from torchtune.modules import CausalSelfAttention, KVCache
from torchtune.modules import LayerDropout, create_layer_dropout_modules


class TransformerDecoderLayer(nn.Module):
@@ -121,6 +123,8 @@ class TransformerDecoder(nn.Module):
before final MLP.
output (nn.Linear): Callable that applies a linear transformation to the output of
the decoder.
layer_dropout_prob (float): Probability of skipping samples in the transformer
layer.

Note:
Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
@@ -138,6 +142,9 @@ def __init__(
head_dim: int,
norm: nn.Module,
output: nn.Linear,
layer_dropout_prob: float = 0.5,
layer_dropout_prob_layer_scale: str = "exp",
layer_dropout_str: str = ":",
) -> None:
super().__init__()

@@ -150,6 +157,9 @@
self.head_dim = head_dim
self.causal_mask = None

self.layer_dropouts = create_layer_dropout_modules(num_layers, layer_dropout_prob, layer_dropout_prob_layer_scale, layer_dropout_str)
self.output_hidden_states = OrderedDict() # TODO: use tensordict?

def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:
"""Setup key value caches for attention calculation.

@@ -188,6 +198,7 @@ def forward(
*,
mask: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
output_hidden_states: Union[bool, List[bool]] = False,
) -> Tensor:
"""
Args:
@@ -227,6 +238,9 @@
# shape: [b, s, d]
h = self.tok_embeddings(tokens)

if isinstance(output_hidden_states, bool):
output_hidden_states = [output_hidden_states] * len(self.layers)

if self.causal_mask is not None:
if input_pos is None:
raise ValueError(
Expand All @@ -240,9 +254,11 @@ def forward(
# in most cases input_pos_len should be 1
mask = self.causal_mask[None, input_pos]

for layer in self.layers:
for i, layer in enumerate(self.layers):
# shape: [b, s, d]
h = layer(h, mask=mask, input_pos=input_pos)
h = self.layer_dropouts[i](layer, h, mask=mask, input_pos=input_pos)
if output_hidden_states[i]:
self.output_hidden_states[i] = h

# shape: [b, s, d]
h = self.norm(h)
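
Putting the pieces together, a rough sketch (not part of the diff) of how a caller uses the new forward flag and the output_hidden_states attribute, assuming model is a TransformerDecoder and tokens is a [b, s] batch as in the recipe above:

want = [i in (3, 7) for i in range(len(model.layers))]    # capture hidden states of layers 3 and 7
logits = model(tokens, output_hidden_states=want)
for layer_id, h in model.output_hidden_states.items():    # OrderedDict of {layer_idx: [b, s, d]}
    early_logits = model.output(model.norm(h))             # unembed an intermediate hidden state

Note that in this version the dict is never cleared between forward calls; a later commit ("clear hidden_states at ..." in the commit list above) appears to address that.
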