add group_by_length optional feature #398

Open · wants to merge 3 commits into base: main
49 changes: 47 additions & 2 deletions finetune/adapter.py
@@ -30,6 +30,8 @@
 from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader


 instruction_tuning = True
@@ -113,15 +115,18 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.

     Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
     """
     step_count = 0

-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         if step_count <= warmup_iters:
             # linear warmup
             lr = learning_rate * step_count / warmup_iters
@@ -223,6 +228,46 @@ def pad_right(x, pad_id):
     return x, y


+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
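
For illustration only, not part of this PR's diff: a minimal sketch of what the collate_fn added above produces for two variable-length samples. It assumes pad id 0 for the inputs (pad_sequence's default padding_value) and -1 for the labels, matching the ignore index already used by the loss in these finetune scripts.

import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples of different lengths, shaped like what InstructionDataset returns.
x = [torch.tensor([5, 6, 7], dtype=torch.int64), torch.tensor([8, 9], dtype=torch.int64)]
y = [torch.tensor([5, 6, 7], dtype=torch.int64), torch.tensor([8, 9], dtype=torch.int64)]

batch_x = pad_sequence(x, batch_first=True)                    # inputs padded with 0
batch_y = pad_sequence(y, batch_first=True, padding_value=-1)  # labels padded with -1

print(batch_x)  # [[5, 6, 7], [8, 9, 0]]
print(batch_y)  # [[5, 6, 7], [8, 9, -1]]
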
49 changes: 47 additions & 2 deletions finetune/adapter_v2.py
@@ -36,6 +36,8 @@
 from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader


 eval_interval = 600
@@ -119,15 +121,18 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.

     Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
     """
     step_count = 0

-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         if step_count <= warmup_iters:
             # linear warmup
             lr = learning_rate * step_count / warmup_iters
@@ -227,6 +232,46 @@ def pad_right(x, pad_id):
     return x, y


+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
49 changes: 47 additions & 2 deletions finetune/full.py
@@ -25,6 +25,8 @@
 from lit_llama.tokenizer import Tokenizer
 from lit_llama.utils import save_model_checkpoint
 from scripts.prepare_alpaca import generate_prompt
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader


 instruction_tuning = True
@@ -95,6 +97,7 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.

@@ -103,8 +106,10 @@
     step_count = 0
     model.train()

-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0

         if step_count <= warmup_iters:
@@ -208,6 +213,46 @@ def pad_right(x, pad_id):
     return x, y


+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
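
Again for illustration only, a hedged sketch of how get_dataloader above might be exercised end to end. The toy data mimics the list-of-dicts format produced by scripts/prepare_alpaca.py (variable-length "input_ids" and "labels" tensors per sample); the Fabric setup, the sample sizes, and the importability of get_dataloader from the finetune script are assumptions of this sketch, not something the diff provides.

import lightning as L
import torch

# Hypothetical toy dataset in the prepare_alpaca.py format: one dict per sample,
# with 32 samples of random lengths between 8 and 96 tokens.
data = [
    {"input_ids": torch.randint(1, 100, (n,)), "labels": torch.randint(1, 100, (n,))}
    for n in torch.randint(8, 96, (32,)).tolist()
]

fabric = L.Fabric(accelerator="cpu", devices=1)
fabric.launch()

# With group_by_length=True, LengthGroupedSampler sorts samples by length inside
# shuffled mega-batches, so each micro-batch tends to contain sequences of
# similar length and pad_sequence wastes fewer padding tokens.
loader = get_dataloader(fabric, data, micro_batch_size=4, group_by_length=True)
for input_ids, targets in loader:
    print(input_ids.shape, targets.shape)  # padded only to the longest sample in the batch
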
101 changes: 101 additions & 0 deletions finetune/length_grouped_sampler.py
@@ -0,0 +1,101 @@
# Derived from https://github.com/huggingface/transformers
# ------------------------------------------------------------------------------------------
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------

from typing import Optional, List
import logging
import torch
from torch.utils.data import Dataset, Sampler

logger = logging.getLogger(__name__)


def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
"""
Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
lengths. To do this, the indices are:

- randomly permuted
- grouped in mega-batches of size `mega_batch_mult * batch_size`
- sorted by length in each mega-batch

The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
maximum length placed first, so that an OOM happens sooner rather than later.
"""
# Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
if mega_batch_mult is None:
mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
# Just in case, for tiny datasets
if mega_batch_mult == 0:
mega_batch_mult = 1

# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = mega_batch_mult * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]

# The rest is to get the biggest batch first.
# Since each megabatch is sorted by descending length, the longest element is the first
megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
# Switch to put the longest element in first position
megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]

return [i for megabatch in megabatches for i in megabatch]


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        dataset: Optional[Dataset] = None,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
        generator=None,
    ):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")

        self.batch_size = batch_size
        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            if not isinstance(dataset[0], dict) or model_input_name not in dataset[0]:
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{model_input_name}' key."
                )
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            logger.info(
                "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
            )
            lengths = lengths.tolist()

        self.lengths = lengths
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
        return iter(indices)
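
A small sketch, again not part of the PR, of how get_length_grouped_indices behaves on a toy list of lengths: mega-batches are drawn at random, sorted by length internally, and the final swap guarantees the very first batch contains the globally longest sample. The import assumes finetune/ is on the Python path, and the explicit mega_batch_mult is chosen only so this tiny example forms more than one mega-batch.

import torch
from length_grouped_sampler import get_length_grouped_indices  # the file added above

lengths = [3, 50, 7, 48, 5, 52, 9, 49]           # toy sequence lengths
g = torch.Generator().manual_seed(0)             # fixed seed so the run is reproducible

idx = get_length_grouped_indices(lengths, batch_size=2, mega_batch_mult=2, generator=g)

# Mega-batches of 2 * 2 = 4 indices are permuted at random and each is sorted by
# descending length, so consecutive pairs of indices tend to have similar lengths,
# and the longest sample (52 tokens) lands in the first pair.
print(idx)
print([lengths[i] for i in idx])
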