From 7bfceafd5e1db984d17ecd268730b256ca5bf64d Mon Sep 17 00:00:00 2001 From: Devin Zhou Date: Tue, 13 Jun 2023 14:31:10 -0700 Subject: [PATCH] Enable Class Balancing for Model Train Sampler Summary: This diff enables both category and datasets weight balancing at the same time by declaring "WeightedCategoryTrainingSampler" under "SAMPLER_TRAIN" in config file. Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/4995 X-link: https://github.com/facebookresearch/d2go/pull/570 Differential Revision: D46377371 fbshipit-source-id: 8fdb87f6d844ca7d05a1e90c37e47984ea8d40e3 --- detectron2/data/build.py | 96 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 4 deletions(-) diff --git a/detectron2/data/build.py b/detectron2/data/build.py index daa27a89f7..2574a5f20f 100644 --- a/detectron2/data/build.py +++ b/detectron2/data/build.py @@ -1,14 +1,14 @@ # Copyright (c) Facebook, Inc. and its affiliates. import itertools import logging -import numpy as np import operator import pickle +from collections import defaultdict, OrderedDict from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np import torch import torch.utils.data as torchdata -from tabulate import tabulate -from termcolor import colored from detectron2.config import configurable from detectron2.structures import BoxMode @@ -16,9 +16,16 @@ from detectron2.utils.env import seed_all_rng from detectron2.utils.file_io import PathManager from detectron2.utils.logger import _log_api_usage, log_first_n +from tabulate import tabulate +from termcolor import colored from .catalog import DatasetCatalog, MetadataCatalog -from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset +from .common import ( + AspectRatioGroupedDataset, + DatasetFromList, + MapDataset, + ToIterableDataset, +) from .dataset_mapper import DatasetMapper from .detection_utils import check_metadata_consistency from .samplers import ( @@ -28,6 +35,7 
logger = logging.getLogger(__name__)

"""
This file contains the default logic to build a dataloader for training or testing.
"""


def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]:
    """
    Read the per-dataset repeat factors from ``cfg.DATASETS.TRAIN_REPEAT_FACTOR``.

    Args:
        cfg: config exposing ``DATASETS.TRAIN`` (names of the training datasets)
            and ``DATASETS.TRAIN_REPEAT_FACTOR`` (a sequence of
            ``(dataset_name, repeat_factor)`` pairs).

    Returns:
        Mapping from dataset name to repeat factor. It is a ``defaultdict``, so a
        dataset without an explicit entry resolves to ``1.0``.

    Raises:
        ValueError: if an entry is not a 2-element pair, or if a factor is given
            for a dataset that is not listed in ``cfg.DATASETS.TRAIN``.
    """
    repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR

    # Explicit raises (instead of ``assert``) so the validation survives
    # ``python -O`` and matches the ValueError used for unknown sampler names.
    malformed = [entry for entry in repeat_factors if len(entry) != 2]
    if malformed:
        raise ValueError(
            "DATASETS.TRAIN_REPEAT_FACTOR entries must be (name, factor) pairs, "
            f"got: {malformed}"
        )

    # Float default keeps the mapping consistent with the declared return type.
    name_to_weight: Dict[str, float] = defaultdict(lambda: 1.0, dict(repeat_factors))

    # The sampling weights map should only contain datasets in train config
    unrecognized = set(name_to_weight) - set(cfg.DATASETS.TRAIN)
    if unrecognized:
        raise ValueError(f"unrecognized datasets: {unrecognized}")

    logger.info("Found repeat factors: %s", list(name_to_weight.items()))
    return name_to_weight


def _build_weighted_sampler(cfg, enable_category_balance=False):
    """
    Build a ``RepeatFactorTrainingSampler`` weighted by per-dataset repeat factors.

    Every sample of a training dataset gets that dataset's repeat factor. When
    ``enable_category_balance`` is True, category-frequency repeat factors are
    computed for each dataset separately, multiplied element-wise with the
    dataset factors, and the result is divided by its minimum so the smallest
    factor becomes 1.

    ``_train_loader_from_config`` selects this builder for the sampler names
    "WeightedTrainingSampler" (``enable_category_balance=False``) and
    "WeightedCategoryTrainingSampler" (``enable_category_balance=True``).

    Args:
        cfg: config providing ``DATASETS.TRAIN``, ``DATASETS.TRAIN_REPEAT_FACTOR``,
            ``DATASETS.PROPOSAL_FILES_TRAIN``, ``DATALOADER.FILTER_EMPTY_ANNOTATIONS``,
            ``DATALOADER.REPEAT_THRESHOLD`` and the ``MODEL`` keypoint/proposal
            switches read below.
        enable_category_balance (bool): also balance by category frequency.

    Returns:
        RepeatFactorTrainingSampler: sampler built from one repeat factor per
        sample, ordered dataset by dataset as listed in ``cfg.DATASETS.TRAIN``.

    Raises:
        ValueError: if proposals are enabled but the number of proposal files
            differs from the number of training datasets.
    """
    dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg)
    dataset_names = list(cfg.DATASETS.TRAIN)

    proposal_files = (
        cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None
    )
    if proposal_files is not None and len(proposal_files) != len(dataset_names):
        raise ValueError(
            "DATASETS.PROPOSAL_FILES_TRAIN must list one file per dataset in "
            f"DATASETS.TRAIN, got {len(proposal_files)} files for "
            f"{len(dataset_names)} datasets"
        )
    min_keypoints = (
        cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0
    )

    # Load each dataset on its own so per-sample factors can be attributed to it.
    # Each single-name call receives only its own proposal file: the loader
    # expects the proposal list to match the name list one-to-one.
    # A list (not a name-keyed dict) keeps the order of cfg.DATASETS.TRAIN and
    # stays aligned even if a dataset name is listed more than once.
    per_dataset_dicts = [
        get_detection_dataset_dicts(
            [name],
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=min_keypoints,
            proposal_files=(
                None if proposal_files is None else [proposal_files[idx]]
            ),
        )
        for idx, name in enumerate(dataset_names)
    ]

    # Repeat factor for every sample in the dataset. The explicit float dtype
    # avoids an integer tensor when all configured factors are ints.
    repeat_factors = torch.tensor(
        list(
            itertools.chain.from_iterable(
                itertools.repeat(dataset_repeat_factors[name], len(dicts))
                for name, dicts in zip(dataset_names, per_dataset_dicts)
            )
        ),
        dtype=torch.float32,
    )

    if enable_category_balance:
        # 1. Calculate repeat factors using category frequency for each dataset
        #    and flatten them in dataset order.
        # 2. Multiply them element-wise with the dataset repeat factors to get
        #    the final repeat factors, then rescale so the minimum is 1.
        category_repeat_factors = torch.tensor(
            list(
                itertools.chain.from_iterable(
                    RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                        dicts, cfg.DATALOADER.REPEAT_THRESHOLD
                    )
                    for dicts in per_dataset_dicts
                )
            )
        )
        repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
        repeat_factors = repeat_factors / torch.min(repeat_factors)
        sampler_label = "WeightedCategoryTrainingSampler"
    else:
        sampler_label = "WeightedTrainingSampler"

    logger.info(
        "Using %s with repeat_factors=%s",
        sampler_label,
        cfg.DATASETS.TRAIN_REPEAT_FACTOR,
    )
    return RepeatFactorTrainingSampler(repeat_factors)