Enable Class Balancing for Model Train Sampler
Summary:
Pull Request resolved: facebookresearch#4995

X-link: facebookresearch/d2go#570

Differential Revision: D46377371

fbshipit-source-id: 47e6f636a8f27cec28e202cf4080dd54c58ccac0
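
This diff adds two new training sampler options, WeightedTrainingSampler and
WeightedCategoryTrainingSampler, both built on RepeatFactorTrainingSampler: the
first balances sampling across datasets via per-dataset repeat factors, and the
second additionally balances category frequencies. Below is a minimal sketch of
how the new samplers might be enabled through the config; the dataset names are
hypothetical, and DATASETS.TRAIN_REPEAT_FACTOR may need to be defined in the
config defaults first (that change is not part of this diff):

    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.DATASETS.TRAIN = ("dataset_a", "dataset_b")  # hypothetical registered datasets

    # Per-dataset repeat factors as (dataset_name, weight) pairs; any train
    # dataset not listed here defaults to a weight of 1.
    cfg.DATASETS.TRAIN_REPEAT_FACTOR = [("dataset_a", 3.0)]

    # Balance by dataset frequency only:
    cfg.DATALOADER.SAMPLER_TRAIN = "WeightedTrainingSampler"

    # Or additionally balance by category frequency:
    cfg.DATALOADER.SAMPLER_TRAIN = "WeightedCategoryTrainingSampler"
    cfg.DATALOADER.REPEAT_THRESHOLD = 0.5  # threshold used for category balancing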
dxzhou2023 authored and facebook-github-bot committed Jun 12, 2023
1 parent 027099d commit df4a040
Showing 1 changed file with 87 additions and 4 deletions.
91 changes: 87 additions & 4 deletions detectron2/data/build.py
@@ -1,24 +1,31 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import itertools
 import logging
-import numpy as np
 import operator
 import pickle
+from collections import defaultdict, OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
 import torch
 import torch.utils.data as torchdata
+from tabulate import tabulate
+from termcolor import colored
 
 from detectron2.config import configurable
 from detectron2.structures import BoxMode
 from detectron2.utils.comm import get_world_size
 from detectron2.utils.env import seed_all_rng
 from detectron2.utils.file_io import PathManager
 from detectron2.utils.logger import _log_api_usage, log_first_n
-from tabulate import tabulate
-from termcolor import colored
 
 from .catalog import DatasetCatalog, MetadataCatalog
-from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
+from .common import (
+    AspectRatioGroupedDataset,
+    DatasetFromList,
+    MapDataset,
+    ToIterableDataset,
+)
 from .dataset_mapper import DatasetMapper
 from .detection_utils import check_metadata_consistency
 from .samplers import (
@@ -28,6 +35,8 @@
     TrainingSampler,
 )
 
+logger = logging.getLogger(__name__)
+
 """
 This file contains the default logic to build a dataloader for training or testing.
 """
@@ -339,6 +348,76 @@ def build_batch_data_loader(
         )
 
 
+def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]:
+    repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR
+    assert all(len(tup) == 2 for tup in repeat_factors)
+    name_to_weight = defaultdict(lambda: 1, dict(repeat_factors))
+    # The sampling weights map should only contain datasets in the train config
+    unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN)
+    assert not unrecognized, f"unrecognized datasets: {unrecognized}"
+
+    logger.info(f"Found repeat factors: {list(name_to_weight.items())}")
+
+    # pyre-fixme[7]: Expected `Dict[str, float]` but got `DefaultDict[typing.Any, int]`.
+    return name_to_weight
+
+
+def _build_weighted_sampler(cfg, mapper=None, enable_category_balance=False):
+    dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg)
+    # OrderedDict to guarantee order of values() consistent with repeat factors
+    dataset_name_to_dicts = OrderedDict(
+        {
+            name: get_detection_dataset_dicts(
+                [name],
+                filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
+                min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+                if cfg.MODEL.KEYPOINT_ON
+                else 0,
+                proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
+                if cfg.MODEL.LOAD_PROPOSALS
+                else None,
+            )
+            for name in cfg.DATASETS.TRAIN
+        }
+    )
+    # Repeat factor for every sample in the dataset
+    repeat_factors = [
+        [dataset_repeat_factors[dsname]] * len(dataset_name_to_dicts[dsname])
+        for dsname in cfg.DATASETS.TRAIN
+    ]
+    repeat_factors = list(itertools.chain.from_iterable(repeat_factors))
+
+    dataset_dicts = dataset_name_to_dicts.values()
+    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
+    dataset = DatasetFromList(dataset_dicts, copy=False)
+    if mapper is None:
+        mapper = DatasetMapper(cfg, True)
+    dataset = MapDataset(dataset, mapper)
+
+    logger.info(
+        "Using WeightedTrainingSampler with repeat_factors={}".format(
+            cfg.DATASETS.TRAIN_REPEAT_FACTOR
+        )
+    )
+    repeat_factors = torch.tensor(repeat_factors)
+    if enable_category_balance:
+        """
+        The repeat factors above already balance dataset frequency; use the
+        annotations in dataset_dicts to compute repeat factors that balance category frequencies.
+        Element-wise multiplication of these two lists of repeat factors gives the final balanced repeat factors.
+        """
+        class_balance_repeat_factors = (
+            RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
+            )
+        )
+        repeat_factors = torch.mul(class_balance_repeat_factors, repeat_factors)
+        repeat_factors = repeat_factors / torch.min(repeat_factors)
+
+    sampler = RepeatFactorTrainingSampler(repeat_factors)
+    return sampler
+
+
 def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
     if dataset is None:
         dataset = get_detection_dataset_dicts(
@@ -373,6 +452,10 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
                 sampler = RandomSubsetTrainingSampler(
                     len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
                 )
+            elif sampler_name == "WeightedTrainingSampler":
+                sampler = _build_weighted_sampler(cfg, mapper)
+            elif sampler_name == "WeightedCategoryTrainingSampler":
+                sampler = _build_weighted_sampler(cfg, mapper, enable_category_balance=True)
             else:
                 raise ValueError("Unknown training sampler: {}".format(sampler_name))
 
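
For reference, a small worked sketch (not part of the commit) of how the two
sets of repeat factors combine in the enable_category_balance branch of
_build_weighted_sampler above; all numbers are illustrative:

    import torch

    # Per-sample factors expanded from the dataset weights, e.g. three images
    # from a dataset weighted 2.0 and two images from a dataset weighted 1.0:
    dataset_factors = torch.tensor([2.0, 2.0, 2.0, 1.0, 1.0])

    # Hypothetical per-sample factors from category-frequency balancing, as
    # RepeatFactorTrainingSampler.repeat_factors_from_category_frequency returns:
    category_factors = torch.tensor([1.0, 1.5, 1.0, 3.0, 1.0])

    # Element-wise product, then normalize so the smallest factor is 1.0:
    combined = torch.mul(category_factors, dataset_factors)
    combined = combined / torch.min(combined)
    print(combined)  # tensor([2., 3., 2., 3., 1.])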
