From 99f6b5537dafb6d8c9e01c1afacbbca7bad6929f Mon Sep 17 00:00:00 2001 From: Devin Zhou Date: Mon, 12 Jun 2023 15:27:10 -0700 Subject: [PATCH] Enable Class Balancing for Model Train Sampler Summary: X-link: https://github.com/facebookresearch/detectron2/pull/4995 Pull Request resolved: https://github.com/facebookresearch/d2go/pull/570 Differential Revision: D46377371 fbshipit-source-id: 8e5e8b859a77eb291bfa30e54ce45126e0d2cd60 --- d2go/data/build.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/d2go/data/build.py b/d2go/data/build.py index f5606f31..d2c91d8a 100644 --- a/d2go/data/build.py +++ b/d2go/data/build.py @@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]: return name_to_weight -def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None): +def build_weighted_detection_train_loader( + cfg: CfgNode, mapper=None, enable_category_balance=False +): dataset_repeat_factors = get_train_datasets_repeat_factors(cfg) # OrderedDict to guarantee order of values() consistent with repeat factors dataset_name_to_dicts = OrderedDict( @@ -103,7 +105,22 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None): cfg.DATASETS.TRAIN_REPEAT_FACTOR ) ) - sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors)) + repeat_factors = torch.tensor(repeat_factors) + if enable_category_balance: + """ + Current repeat factor is already balanced by dataset frequency + Use annotations information in dataset_dicts to compute repeat factors to balance category frequencies. + Element wise dot producting these two list of repeat factors gives the final balanced repeat factors + """ + class_balance_repeat_factors = ( + RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD + ) + ) + repeat_factors = torch.mul(class_balance_repeat_factors, repeat_factors) + repeat_factors = repeat_factors / torch.min(repeat_factors) + + sampler = RepeatFactorTrainingSampler(repeat_factors) return build_batch_data_loader( dataset, @@ -149,7 +166,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work @fb_overwritable() def build_mapped_train_loader(cfg, mapper): if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler": + # balancing only datasets frequencies data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper) + elif cfg.DATALOADER.SAMPLER_TRAIN == "WeightedCategoryTrainingSampler": + # balancing both datasets and its categories + data_loader = build_weighted_detection_train_loader( + cfg, mapper=mapper, enable_class_balance=True + ) else: data_loader = build_detection_train_loader(cfg, mapper=mapper) return data_loader