per-micro-batch input loader #5635

Merged · 20 commits · Feb 9, 2023
4 changes: 2 additions & 2 deletions Jenkinsfile
@@ -1,7 +1,7 @@
pipeline {
  agent {
    docker {
      image 'nvcr.io/nvidia/pytorch:23.01-py3'
      image 'nemo_containers:23.01_apex_c3d575f2478cd379b3c2d81f41edde39791b5d92'
      args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache:/root/.cache --shm-size=8g'
    }
  }
@@ -4509,4 +4509,4 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
      cleanWs()
    }
  }
}
}
136 changes: 97 additions & 39 deletions nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
@@ -14,43 +14,84 @@

"""Dataloaders."""

import abc
from typing import Optional

import torch

from nemo.utils import logging


class MegatronPretrainingSampler:
class BaseMegatronSampler:
    def __init__(
        self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
    ):
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool = True,
        global_batch_size: Optional[int] = None,
        pad_samples_to_global_batch_size: Optional[bool] = False,
    ) -> None:
        # Sanity checks.
        if total_samples <= 0:
            raise RuntimeError("no sample to consume: {}".format(total_samples))
        if consumed_samples >= total_samples:
            raise RuntimeError("no samples left to consume: {}, {}".format(consumed_samples, total_samples))
        if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size must be greater than 0, but got {micro_batch_size}")
        if data_parallel_size <= 0:
            raise RuntimeError(f"data_parallel_size must be greater than 0, but got {data_parallel_size}")
        if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data parallel size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
            )
        if global_batch_size is not None:
            if global_batch_size % (micro_batch_size * data_parallel_size) != 0:
                raise RuntimeError(
                    f"`global_batch_size` ({global_batch_size}) is not divisible by "
                    f"`micro_batch_size ({micro_batch_size}) x data_parallel_size "
                    f"({data_parallel_size})`"
                )
        if pad_samples_to_global_batch_size and global_batch_size is None:
            raise RuntimeError(
                "`pad_samples_to_global_batch_size` can be `True` only when "
                "`global_batch_size` is set to an integer value"
            )

        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last
        self.global_batch_size = global_batch_size
        self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size

        logging.info(
            f'Instantiating MegatronPretrainingSampler with total_samples: {total_samples} and consumed_samples: {consumed_samples}'
        )

        # Sanity checks.
        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, 'no samples left to consume: {}, {}'.format(
            self.consumed_samples, self.total_samples
        )
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, (
            'data_parallel_rank should be smaller than data size: {}, '
            '{}'.format(self.data_parallel_rank, data_parallel_size)
        )

    def __len__(self):
        return (self.total_samples - self.consumed_samples - 1) // self.micro_batch_times_data_parallel_size + 1
        num_available_samples: int = self.total_samples - self.consumed_samples
        if self.global_batch_size is not None:
            if self.drop_last:
                return num_available_samples // self.global_batch_size
            else:
                return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
        else:
            return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

    @abc.abstractmethod
    def __iter__(self):
        ...

Check notice (Code scanning / CodeQL): "Statement has no effect", flagged on the `...` body of the abstract method above.

class MegatronPretrainingSampler(BaseMegatronSampler):
    def get_start_end_idx(self):
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
@@ -68,32 +109,45 @@ def __iter__(self):

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]

            if self.pad_samples_to_global_batch_size:
                for i in range(
                    self.data_parallel_rank, self.global_batch_size, self.micro_batch_times_data_parallel_size
                ):
                    indices = [batch[j] for j in range(i, min(len(batch), i + self.micro_batch_size))]
                    num_pad = self.micro_batch_size - len(indices)
                    indices = indices + [-1] * num_pad
                    yield indices
            else:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
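
The padding branch above is easier to follow in isolation. A standalone sketch of the same loop, with hypothetical values and only data-parallel rank 0 shown:

# Standalone sketch of the pad_samples_to_global_batch_size branch
# (hypothetical values: 5 leftover indices, global batch 8, micro batch 2,
# two data-parallel ranks; only rank 0 is shown).
batch = [100, 101, 102, 103, 104]  # leftover sample indices
global_batch_size = 8
micro_batch_size = 2
data_parallel_rank = 0
micro_batch_times_data_parallel_size = micro_batch_size * 2  # dp size 2

for i in range(data_parallel_rank, global_batch_size, micro_batch_times_data_parallel_size):
    indices = [batch[j] for j in range(i, min(len(batch), i + micro_batch_size))]
    indices = indices + [-1] * (micro_batch_size - len(indices))
    print(indices)
# Rank 0 yields [100, 101], then [104, -1]; the -1 entries mark padding
# samples that the downstream collate/forward step presumably masks out.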

class MegatronPretrainingRandomSampler:
    def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, (
            'data_parallel_rank should be smaller than data size: {}, '
            '{}'.format(self.data_parallel_rank, data_parallel_size)
        )
class MegatronPretrainingRandomSampler(BaseMegatronSampler):
    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool = True,
        global_batch_size: Optional[int] = None,
        pad_samples_to_global_batch_size: Optional[bool] = False,
    ) -> None:
        super().__init__(
            total_samples=total_samples,
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            drop_last=drop_last,
            global_batch_size=global_batch_size,
            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
        )

    def __len__(self):
        return self.total_samples
        assert (
            pad_samples_to_global_batch_size == False
        ), "`MegatronPretrainingRandomSampler` does not support sample padding"
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
@@ -119,3 +173,7 @@ def __iter__(self):
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            yield batch
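
For orientation, these samplers are batch samplers: each step of `__iter__` yields a list of dataset indices for one micro-batch on the local data-parallel rank. A minimal usage sketch (the toy dataset and parameter values are hypothetical; the import path is the module shown above):

import torch
from torch.utils.data import DataLoader, Dataset

from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
    MegatronPretrainingSampler,
)

# Toy dataset standing in for a Megatron indexed dataset.
class RangeDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return torch.tensor(idx)

dataset = RangeDataset()
sampler = MegatronPretrainingSampler(
    total_samples=len(dataset),
    consumed_samples=0,
    micro_batch_size=4,
    data_parallel_rank=0,  # in real use, taken from the parallel state
    data_parallel_size=2,
    drop_last=True,
    global_batch_size=16,
)

# As a batch_sampler, it yields index lists one micro-batch at a time.
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)
for micro_batch in loader:
    pass  # each micro_batch is a tensor of micro_batch_size indices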
@@ -242,7 +242,7 @@ def configure_gradient_clipping(self, *args, **kwargs):
        parameters = self._get_parameters()
        grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)

        self.log('grad_norm', grad_norm, rank_zero_only=True)
        self.log('grad_norm', grad_norm, rank_zero_only=True, batch_size=1)

    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks.
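
On the last hunk: Lightning's `self.log` tries to infer a batch size from the current batch for epoch-level aggregation, and that inference is presumably unreliable for Megatron's micro-batched inputs, so the call pins it explicitly. A minimal sketch of the same pattern (the module and grad_norm value are hypothetical stand-ins):

import pytorch_lightning as pl
import torch

class TinyModule(pl.LightningModule):
    def configure_gradient_clipping(self, *args, **kwargs):
        grad_norm = torch.tensor(0.0)  # stand-in for clip_grad_norm_fp32
        # batch_size=1 stops Lightning from trying to infer a batch size
        # from a non-standard batch object when averaging the metric.
        self.log('grad_norm', grad_norm, rank_zero_only=True, batch_size=1)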