diff --git a/d2go/data/build.py b/d2go/data/build.py index f5606f31..305e9b60 100644 --- a/d2go/data/build.py +++ b/d2go/data/build.py @@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]: return name_to_weight -def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None): +def build_weighted_detection_train_loader( + cfg: CfgNode, mapper=None, enable_category_balance=False +): dataset_repeat_factors = get_train_datasets_repeat_factors(cfg) # OrderedDict to guarantee order of values() consistent with repeat factors dataset_name_to_dicts = OrderedDict( @@ -103,7 +105,28 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None): cfg.DATASETS.TRAIN_REPEAT_FACTOR ) ) - sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors)) + repeat_factors = torch.tensor(repeat_factors) + if enable_category_balance: + """ + 1. Calculate repeat factors using category frequency for each dataset and then merge them. + 2. Element wise dot producting the dataset frequency repeat factors with + the category frequency repeat factors gives the final repeat factors. + """ + category_repeat_factors = [ + RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD + ) + for dataset_dict in dataset_name_to_dicts.values() + ] + # flatten the category repeat factors from all datasets + category_repeat_factors = list( + itertools.chain.from_iterable(category_repeat_factors) + ) + category_repeat_factors = torch.tensor(category_repeat_factors) + repeat_factors = torch.mul(category_repeat_factors, repeat_factors) + repeat_factors = repeat_factors / torch.min(repeat_factors) + + sampler = RepeatFactorTrainingSampler(repeat_factors) return build_batch_data_loader( dataset, @@ -149,7 +172,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work @fb_overwritable() def build_mapped_train_loader(cfg, mapper): if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler": + # balancing only datasets frequencies data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper) + elif cfg.DATALOADER.SAMPLER_TRAIN == "WeightedCategoryTrainingSampler": + # balancing both datasets and its categories + data_loader = build_weighted_detection_train_loader( + cfg, mapper=mapper, enable_category_balance=True + ) else: data_loader = build_detection_train_loader(cfg, mapper=mapper) return data_loader