Skip to content

Commit

Permalink
Enable Class Balancing for Model Train Sampler
Browse files Browse the repository at this point in the history
Summary:
This diff enables both category and datasets weight balancing at the same time by declaring "WeightedCategoryTrainingSampler" under "SAMPLER_TRAIN" in config file.

Pull Request resolved: facebookresearch#4995

X-link: facebookresearch/d2go#570

Differential Revision: D46377371

fbshipit-source-id: b4cef667ea7fd1955bb189ccca81d4ec99f8f52b
  • Loading branch information
dxzhou2023 authored and facebook-github-bot committed Jun 14, 2023
1 parent 94113be commit 530566b
Showing 1 changed file with 92 additions and 4 deletions.
96 changes: 92 additions & 4 deletions detectron2/data/build.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import numpy as np
import operator
import pickle
from collections import defaultdict, OrderedDict
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
import torch.utils.data as torchdata
from tabulate import tabulate
from termcolor import colored

from detectron2.config import configurable
from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.env import seed_all_rng
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import _log_api_usage, log_first_n
from tabulate import tabulate
from termcolor import colored

from .catalog import DatasetCatalog, MetadataCatalog
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
from .common import (
AspectRatioGroupedDataset,
DatasetFromList,
MapDataset,
ToIterableDataset,
)
from .dataset_mapper import DatasetMapper
from .detection_utils import check_metadata_consistency
from .samplers import (
Expand All @@ -28,6 +35,7 @@
TrainingSampler,
)

logger = logging.getLogger(__name__)
"""
This file contains the default logic to build a dataloader for training or testing.
"""
Expand Down Expand Up @@ -339,6 +347,82 @@ def build_batch_data_loader(
)


def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]:
repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR
assert all(len(tup) == 2 for tup in repeat_factors)
name_to_weight = defaultdict(lambda: 1, dict(repeat_factors))
# The sampling weights map should only contain datasets in train config
unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN)
assert not unrecognized, f"unrecognized datasets: {unrecognized}"

logger.info(f"Found repeat factors: {list(name_to_weight.items())}")

# pyre-fixme[7]: Expected `Dict[str, float]` but got `DefaultDict[typing.Any, int]`.
return name_to_weight


def _build_weighted_sampler(cfg, enable_category_balance=False):
dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg)
# OrderedDict to guarantee order of values() consistent with repeat factors
dataset_name_to_dicts = OrderedDict(
{
name: get_detection_dataset_dicts(
[name],
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
if cfg.MODEL.KEYPOINT_ON
else 0,
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
if cfg.MODEL.LOAD_PROPOSALS
else None,
)
for name in cfg.DATASETS.TRAIN
}
)
# Repeat factor for every sample in the dataset
repeat_factors = [
[dataset_repeat_factors[dsname]] * len(dataset_name_to_dicts[dsname])
for dsname in cfg.DATASETS.TRAIN
]

repeat_factors = list(itertools.chain.from_iterable(repeat_factors))

repeat_factors = torch.tensor(repeat_factors)
if enable_category_balance:
"""
1. Calculate repeat factors using category frequency for each dataset and then merge them.
2. Element wise dot producting the dataset frequency repeat factors with
the category frequency repeat factors gives the final repeat factors.
"""
category_repeat_factors = [
RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
)
for dataset_dict in dataset_name_to_dicts.values()
]
# flatten the category repeat factors from all datasets
category_repeat_factors = list(
itertools.chain.from_iterable(category_repeat_factors)
)
category_repeat_factors = torch.tensor(category_repeat_factors)
repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
repeat_factors = repeat_factors / torch.min(repeat_factors)
logger.info(
"Using WeightedCategoryTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
)
else:
logger.info(
"Using WeightedTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
)

sampler = RepeatFactorTrainingSampler(repeat_factors)
return sampler


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
if dataset is None:
dataset = get_detection_dataset_dicts(
Expand Down Expand Up @@ -373,6 +457,10 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
sampler = RandomSubsetTrainingSampler(
len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
)
elif sampler_name == "WeightedTrainingSampler":
sampler = _build_weighted_sampler(cfg)
elif sampler_name == "WeightedCategoryTrainingSampler":
sampler = _build_weighted_sampler(cfg, enable_category_balance=True)
else:
raise ValueError("Unknown training sampler: {}".format(sampler_name))

Expand Down

0 comments on commit 530566b

Please sign in to comment.