Skip to content

Commit

Permalink
Enable Class Balancing for Model Train Sampler
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/detectron2#4995

Pull Request resolved: facebookresearch#570

Differential Revision: D46377371

fbshipit-source-id: c547997c71152d75011cbe94a0c762d20aa33669
  • Loading branch information
dxzhou2023 authored and facebook-github-bot committed Jun 13, 2023
1 parent 74825f6 commit b457ea9
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions d2go/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
return name_to_weight


def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
def build_weighted_detection_train_loader(
cfg: CfgNode, mapper=None, enable_category_balance=False
):
dataset_repeat_factors = get_train_datasets_repeat_factors(cfg)
# OrderedDict to guarantee order of values() consistent with repeat factors
dataset_name_to_dicts = OrderedDict(
Expand Down Expand Up @@ -103,7 +105,27 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
)
sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors))
repeat_factors = torch.tensor(repeat_factors)
if enable_category_balance:
"""
1. Calculate repeat factors using category frequency for each dataset.
    2. Element-wise multiplication of the dataset frequency repeat factors
    with the category frequency repeat factors gives the final repeat factors.
"""
category_repeat_factors = [
RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
)
for dataset_dict in dataset_name_to_dicts.values()
]
# flatten the category repeat factors from all datasets
category_repeat_factors = list(
itertools.chain.from_iterable(category_repeat_factors)
)
repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
repeat_factors = repeat_factors / torch.min(repeat_factors)

sampler = RepeatFactorTrainingSampler(repeat_factors)

return build_batch_data_loader(
dataset,
Expand Down Expand Up @@ -149,7 +171,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work
@fb_overwritable()
def build_mapped_train_loader(cfg, mapper):
    """Build the training data loader selected by ``cfg.DATALOADER.SAMPLER_TRAIN``.

    Dispatch:
        - "WeightedCategoryTrainingSampler": weighted loader balancing both
          per-dataset and per-category frequencies.
        - "WeightedTrainingSampler": weighted loader balancing per-dataset
          frequencies only.
        - anything else: detectron2's default training loader.

    Args:
        cfg: the config node; only ``DATALOADER.SAMPLER_TRAIN`` is inspected here.
        mapper: dataset-dict mapper forwarded to the underlying loader builder.

    Returns:
        The constructed training data loader.
    """
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    if sampler_name == "WeightedCategoryTrainingSampler":
        # Balance both dataset frequencies and category frequencies.
        return build_weighted_detection_train_loader(
            cfg, mapper=mapper, enable_category_balance=True
        )
    if sampler_name == "WeightedTrainingSampler":
        # Balance dataset frequencies only.
        return build_weighted_detection_train_loader(cfg, mapper=mapper)
    # Fall back to the stock detectron2 training loader.
    return build_detection_train_loader(cfg, mapper=mapper)
Expand Down

0 comments on commit b457ea9

Please sign in to comment.