Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RepeatAugSampler #1678

Merged
merged 4 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mmocr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .ocr_dataset import OCRDataset
from .recog_lmdb_dataset import RecogLMDBDataset
from .recog_text_dataset import RecogTextDataset
from .samplers import * # NOQA
from .transforms import * # NOQA
from .wildreceipt_dataset import WildReceiptDataset

Expand Down
4 changes: 4 additions & 0 deletions mmocr/datasets/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .repeat_aug import RepeatAugSampler

__all__ = ['RepeatAugSampler']
101 changes: 101 additions & 0 deletions mmocr/datasets/samplers/repeat_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import math
from typing import Iterator, Optional, Sized

import torch
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
from torch.utils.data import Sampler

from mmocr.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class RepeatAugSampler(Sampler):
    """Distributed sampler with repeated augmentation.

    Every dataset index is repeated ``num_repeats`` times back to back
    (``[0, 0, 0, 1, 1, 1, ...]``) and the repeated sequence is dealt out to
    the ranks round-robin, so consecutive copies of one sample go to
    consecutive processes (GPUs) and can be augmented independently there.
    Each rank then keeps only its first ``ceil(len(dataset) / world_size)``
    indices, so the epoch length per rank does not grow with ``num_repeats``.
    Heavily based on ``torch.utils.data.DistributedSampler``.

    Adapted from
    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
    Copyright (c) 2015-present, Facebook, Inc.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
        num_repeats (int): The repeat times of every sample. Must be at
            least 1. Defaults to 3.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Defaults to None.

    Raises:
        ValueError: If ``num_repeats`` is smaller than 1.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 num_repeats: int = 3,
                 seed: Optional[int] = None):
        # Fail fast: a non-positive repeat count yields an empty index list
        # and would otherwise only surface as a ZeroDivisionError on the
        # first iteration.
        if num_repeats < 1:
            raise ValueError('`num_repeats` must be a positive integer, '
                             f'but got {num_repeats}.')

        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        # Only the main process warns, to avoid one message per rank.
        if not self.shuffle and is_main_process():
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.warning('The RepeatAugSampler always picks a '
                           'fixed part of data if `shuffle=False`.')

        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.num_repeats = num_repeats

        # The number of repeated samples in the rank
        self.num_samples = math.ceil(
            len(self.dataset) * num_repeats / world_size)
        # The total number of repeated samples in all ranks.
        self.total_size = self.num_samples * world_size
        # The number of selected samples in the rank
        self.num_selected_samples = math.ceil(len(self.dataset) / world_size)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices selected for this rank.

        Returns:
            Iterator[int]: At most ``num_selected_samples`` dataset indices.
        """
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
        indices = [x for x in indices for _ in range(self.num_repeats)]

        # An empty dataset has nothing to pad with or to sample from; bail
        # out before the padding step divides by ``len(indices)``.
        if not indices:
            return iter(indices)

        # add extra samples to make it evenly divisible. The number of
        # copies is an integer ceiling division, which avoids float rounding
        # and never builds more copies than needed.
        copies = -(-self.total_size // len(indices))
        indices = (indices * copies)[:self.total_size]
        assert len(indices) == self.total_size

        # subsample per rank
        indices = indices[self.rank:self.total_size:self.world_size]
        assert len(indices) == self.num_samples

        # return up to num selected samples
        return iter(indices[:self.num_selected_samples])

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_selected_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
98 changes: 98 additions & 0 deletions tests/test_datasets/test_samplers/test_repeat_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.

import math
from unittest import TestCase
from unittest.mock import patch

import torch
from mmengine.logging import MMLogger

from mmocr.datasets import RepeatAugSampler

file = 'mmocr.datasets.samplers.repeat_aug.'


class MockDist:
    """Stand-in for the distributed helpers used by the sampler.

    Args:
        dist_info (tuple): The ``(rank, world_size)`` pair to report.
            Defaults to ``(0, 1)``.
        seed (int): The value reported as the synchronized random seed.
            Defaults to 7.
    """

    def __init__(self, dist_info=(0, 1), seed=7):
        self.seed = seed
        self.dist_info = dist_info

    def get_dist_info(self):
        """Return the configured ``dist_info`` value unchanged."""
        return self.dist_info

    def sync_random_seed(self):
        """Return the configured seed."""
        return self.seed

    def is_main_process(self):
        """Return whether the first entry of ``dist_info`` equals 0."""
        rank = self.dist_info[0]
        return rank == 0


class TestRepeatAugSampler(TestCase):
    """Unit tests for ``RepeatAugSampler``."""

    def setUp(self):
        """Use a plain list of 100 integers as the dataset."""
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch(file + 'get_dist_info', return_value=(0, 1))
    def test_non_dist(self, _dist_info_mock):
        """Single process: sizes, unshuffled order and the warning."""
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
        repeated_length = self.data_length * 3

        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)
        self.assertEqual(sampler.total_size, repeated_length)
        self.assertEqual(sampler.num_samples, repeated_length)
        self.assertEqual(sampler.num_selected_samples, self.data_length)
        self.assertEqual(len(sampler), sampler.num_selected_samples)

        # Position i of [0, 0, 0, 1, 1, 1, ...] holds i // 3; only the first
        # ``data_length`` positions are yielded.
        expected = [pos // 3 for pos in range(self.data_length)]
        self.assertEqual(list(sampler), expected)

        # With shuffle disabled the sampler must log a warning.
        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, 'WARN') as captured:
            RepeatAugSampler(self.dataset, shuffle=False)
        self.assertIn('always picks a fixed part', captured.output[0])

    @patch(file + 'get_dist_info', return_value=(2, 3))
    @patch(file + 'is_main_process', return_value=False)
    def test_dist(self, _main_process_mock, _dist_info_mock):
        """Rank 2 of 3: sizes, strided selection and no warning."""
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)

        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)
        self.assertEqual(sampler.num_samples, self.data_length)
        self.assertEqual(sampler.total_size, self.data_length * 3)
        selected = math.ceil(self.data_length / 3)
        self.assertEqual(sampler.num_selected_samples, selected)
        self.assertEqual(len(sampler), sampler.num_selected_samples)

        # Rank 2 takes every third entry of the repeated list from offset 2.
        repeated = [pos // 3 for pos in range(self.data_length * 3)]
        rank_share = repeated[2::3]
        self.assertEqual(
            list(sampler), rank_share[:sampler.num_selected_samples])

        # Non-main processes must stay silent.
        logger = MMLogger.get_current_instance()
        with patch.object(logger, 'warning') as warning_mock:
            RepeatAugSampler(self.dataset, shuffle=False)
        warning_mock.assert_not_called()

    @patch(file + 'get_dist_info', return_value=(0, 1))
    @patch(file + 'sync_random_seed', return_value=7)
    def test_shuffle(self, _seed_mock, _dist_info_mock):
        """Seed fallback and epoch-dependent shuffled order."""
        # test seed=None
        sampler = RepeatAugSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test random seed: the permutation is seeded with seed + epoch
        epoch = 10
        for seed in (0, 42):
            sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=seed)
            sampler.set_epoch(epoch)

            generator = torch.Generator()
            generator.manual_seed(seed + epoch)
            permutation = torch.randperm(
                len(self.dataset), generator=generator).tolist()
            repeated = [idx for idx in permutation for _ in range(3)]
            expected = repeated[:sampler.num_selected_samples]
            self.assertEqual(list(sampler), expected)