Skip to content

Commit

Permalink
per-micro-batch input loader (#5635)
Browse files (browse the repository at this point in the history)
* per-micro-batch input loader

* per-micro-batch input loader

set arg default val

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* apply per-microbatch-loader to only GPT

* update docstring on micro-batch input loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed the default arg val

* fix batch size to 1 at log stat registration

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update container for CI

Signed-off-by: ericharper <complex451@gmail.com>

* update container in jenkinsfile

Signed-off-by: ericharper <complex451@gmail.com>

* update container for CI

Signed-off-by: ericharper <complex451@gmail.com>

fix merge conflict

* revert Jenkinsfile

* Revert "revert Jenkinsfile"

This reverts commit d23b775.

* Update nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* add GradScaler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ericharper <complex451@gmail.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
  • Loading branch information
4 people authored Feb 9, 2023
1 parent dfdd8f0 commit 3b87f88
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 110 deletions.
4 changes: 2 additions & 2 deletions Jenkinsfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pipeline {
agent {
docker {
image 'nvcr.io/nvidia/pytorch:23.01-py3'
image 'nemo_containers:23.01_apex_c3d575f2478cd379b3c2d81f41edde39791b5d92'
args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache:/root/.cache --shm-size=8g'
}
}
Expand Down Expand Up @@ -4509,4 +4509,4 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
cleanWs()
}
}
}
}
136 changes: 97 additions & 39 deletions nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,84 @@

"""Dataloaders."""

import abc
from typing import Optional

import torch

from nemo.utils import logging


class MegatronPretrainingSampler:
class BaseMegatronSampler:
def __init__(
self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
):
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
global_batch_size: Optional[int] = None,
pad_samples_to_global_batch_size: Optional[bool] = False,
) -> None:
# Sanity checks.
if total_samples <= 0:
raise RuntimeError("no sample to consume: {}".format(total_samples))
if consumed_samples >= total_samples:
raise RuntimeError("no samples left to consume: {}, {}".format(consumed_samples, total_samples))
if micro_batch_size <= 0:
raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise RuntimeError(
"data_parallel_rank should be smaller than data size, but {} >= {}".format(
data_parallel_rank, data_parallel_size
)
)
if global_batch_size is not None:
if global_batch_size % (micro_batch_size * data_parallel_size) != 0:
raise RuntimeError(
f"`global_batch_size` ({global_batch_size}) is not divisible by "
f"`micro_batch_size ({micro_batch_size}) x data_parallel_size "
f"({data_parallel_size})`"
)
if pad_samples_to_global_batch_size and global_batch_size is None:
raise RuntimeError(
f"`pad_samples_to_global_batch_size` can be `True` only when "
f"`global_batch_size` is set to an integer value"
)

# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
self.global_batch_size = global_batch_size
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size

logging.info(
f'Instantiating MegatronPretrainingSampler with total_samples: {total_samples} and consumed_samples: {consumed_samples}'
)

# Sanity checks.
assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, 'no samples left to consume: {}, {}'.format(
self.consumed_samples, self.total_samples
)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, (
'data_parallel_rank should be smaller than data size: {}, '
'{}'.format(self.data_parallel_rank, data_parallel_size)
)

def __len__(self):
return (self.total_samples - self.consumed_samples - 1) // self.micro_batch_times_data_parallel_size + 1
num_available_samples: int = self.total_samples - self.consumed_samples
if self.global_batch_size is not None:
if self.drop_last:
return num_available_samples // self.global_batch_size
else:
return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
else:
return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

@abc.abstractmethod
def __iter__(self):
...


class MegatronPretrainingSampler(BaseMegatronSampler):
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
Expand All @@ -68,32 +109,45 @@ def __iter__(self):

# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]

if self.pad_samples_to_global_batch_size:
for i in range(
self.data_parallel_rank, self.global_batch_size, self.micro_batch_times_data_parallel_size
):
indices = [batch[j] for j in range(i, max(len(batch), i + self.micro_batch_size))]
num_pad = self.micro_batch_size - len(indices)
indices = indices + [-1] * num_pad
yield indices
else:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]

class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

# Sanity checks.
assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, (
'data_parallel_rank should be smaller than data size: {}, '
'{}'.format(self.data_parallel_rank, data_parallel_size)
class MegatronPretrainingRandomSampler(BaseMegatronSampler):
def __init__(
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
global_batch_size: Optional[int] = None,
pad_samples_to_global_batch_size: Optional[bool] = False,
) -> None:
super().__init__(
total_samples=total_samples,
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=data_parallel_rank,
data_parallel_size=data_parallel_size,
drop_last=drop_last,
global_batch_size=global_batch_size,
pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
)

def __len__(self):
return self.total_samples
assert (
pad_samples_to_global_batch_size == False
), "`MegatronPretrainingRandomSampler` does not support sample padding"
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
Expand All @@ -119,3 +173,7 @@ def __iter__(self):
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []

# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
yield batch
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def configure_gradient_clipping(self, *args, **kwargs):
parameters = self._get_parameters()
grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)

self.log('grad_norm', grad_norm, rank_zero_only=True)
self.log('grad_norm', grad_norm, rank_zero_only=True, batch_size=1)

def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks.
Expand Down
Loading

0 comments on commit 3b87f88

Please sign in to comment.