From b457ea910efafad82f5bac93ea2ca9ee93650511 Mon Sep 17 00:00:00 2001
From: Devin Zhou
Date: Tue, 13 Jun 2023 11:04:17 -0700
Subject: [PATCH] Enable Class Balancing for Model Train Sampler

Summary:
X-link: https://github.com/facebookresearch/detectron2/pull/4995

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/570

Differential Revision: D46377371

fbshipit-source-id: c547997c71152d75011cbe94a0c762d20aa33669
---
 d2go/data/build.py | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/d2go/data/build.py b/d2go/data/build.py
index f5606f31..682d01a5 100644
--- a/d2go/data/build.py
+++ b/d2go/data/build.py
@@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
     return name_to_weight
 
 
-def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
+def build_weighted_detection_train_loader(
+    cfg: CfgNode, mapper=None, enable_category_balance=False
+):
     dataset_repeat_factors = get_train_datasets_repeat_factors(cfg)
     # OrderedDict to guarantee order of values() consistent with repeat factors
     dataset_name_to_dicts = OrderedDict(
@@ -103,7 +105,27 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
             cfg.DATASETS.TRAIN_REPEAT_FACTOR
         )
     )
-    sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors))
+    repeat_factors = torch.tensor(repeat_factors)
+    if enable_category_balance:
+        # Category balancing:
+        # 1. Calculate repeat factors using category frequency for each dataset.
+        # 2. Multiply them element-wise with the dataset-frequency repeat factors
+        #    to get the final repeat factors.
+        # The result is then rescaled so the smallest repeat factor is 1.
+        category_repeat_factors = [
+            RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
+            )
+            for dataset_dict in dataset_name_to_dicts.values()
+        ]
+        # flatten across datasets into one tensor (torch.mul rejects a plain list)
+        category_repeat_factors = torch.tensor(
+            list(itertools.chain.from_iterable(category_repeat_factors))
+        )
+        repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
+        repeat_factors = repeat_factors / torch.min(repeat_factors)
+
+    sampler = RepeatFactorTrainingSampler(repeat_factors)
 
     return build_batch_data_loader(
         dataset,
@@ -149,7 +171,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work
 @fb_overwritable()
 def build_mapped_train_loader(cfg, mapper):
     if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
+        # balance dataset frequencies only
         data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
+    elif cfg.DATALOADER.SAMPLER_TRAIN == "WeightedCategoryTrainingSampler":
+        # balance both dataset and category frequencies
+        data_loader = build_weighted_detection_train_loader(
+            cfg, mapper=mapper, enable_category_balance=True
+        )
     else:
         data_loader = build_detection_train_loader(cfg, mapper=mapper)
     return data_loader