Enable Class Balancing for Model Train Sampler
Summary:
This diff enables category and dataset weight balancing at the same time by setting "SAMPLER_TRAIN" to "WeightedCategoryTrainingSampler" in the config file.

X-link: facebookresearch/detectron2#4995

Pull Request resolved: facebookresearch#570

Differential Revision: D46377371

fbshipit-source-id: 8a1f92ae8d003198fab95fdb8b816803db41ae06
dxzhou2023 authored and facebook-github-bot committed Jun 13, 2023
1 parent 3fce52c commit 53752af
Showing 1 changed file with 31 additions and 2 deletions.
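Before the diff itself, a rough sketch of how the new sampler would be selected through the config. This is an illustration only: the dataset names, weights, and threshold below are made up, the TRAIN_REPEAT_FACTOR format is assumed to be (name, weight) pairs, and in practice these keys live in a full d2go config rather than a bare CfgNode.

from detectron2.config import CfgNode  # same CfgNode type annotated in build.py below

cfg = CfgNode()
cfg.DATALOADER = CfgNode()
# Selecting this sampler routes build_mapped_train_loader through
# build_weighted_detection_train_loader(..., enable_category_balance=True).
cfg.DATALOADER.SAMPLER_TRAIN = "WeightedCategoryTrainingSampler"
cfg.DATALOADER.REPEAT_THRESHOLD = 0.5  # hypothetical category-frequency threshold

cfg.DATASETS = CfgNode()
cfg.DATASETS.TRAIN = ("dataset_a_train", "dataset_b_train")  # hypothetical names
# Assumed format: (dataset_name, weight) pairs read by get_train_datasets_repeat_factors.
cfg.DATASETS.TRAIN_REPEAT_FACTOR = [["dataset_a_train", 1.0], ["dataset_b_train", 2.0]]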
d2go/data/build.py (31 additions, 2 deletions)

@@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
     return name_to_weight


-def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
+def build_weighted_detection_train_loader(
+    cfg: CfgNode, mapper=None, enable_category_balance=False
+):
     dataset_repeat_factors = get_train_datasets_repeat_factors(cfg)
     # OrderedDict to guarantee order of values() consistent with repeat factors
     dataset_name_to_dicts = OrderedDict(
@@ -103,7 +105,28 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
             cfg.DATASETS.TRAIN_REPEAT_FACTOR
         )
     )
-    sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors))
+    repeat_factors = torch.tensor(repeat_factors)
+    if enable_category_balance:
+        """
+        1. Calculate repeat factors using category frequency for each dataset and then merge them.
+        2. Element-wise multiplication of the dataset frequency repeat factors with
+           the category frequency repeat factors gives the final repeat factors.
+        """
+        category_repeat_factors = [
+            RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
+            )
+            for dataset_dict in dataset_name_to_dicts.values()
+        ]
+        # flatten the category repeat factors from all datasets
+        category_repeat_factors = list(
+            itertools.chain.from_iterable(category_repeat_factors)
+        )
+        category_repeat_factors = torch.tensor(category_repeat_factors)
+        repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
+        repeat_factors = repeat_factors / torch.min(repeat_factors)
+
+    sampler = RepeatFactorTrainingSampler(repeat_factors)

     return build_batch_data_loader(
         dataset,
@@ -149,7 +172,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_workers):
 @fb_overwritable()
 def build_mapped_train_loader(cfg, mapper):
     if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
+        # balance only the dataset-level frequencies
         data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
+    elif cfg.DATALOADER.SAMPLER_TRAIN == "WeightedCategoryTrainingSampler":
+        # balance both the datasets and their categories
+        data_loader = build_weighted_detection_train_loader(
+            cfg, mapper=mapper, enable_category_balance=True
+        )
     else:
         data_loader = build_detection_train_loader(cfg, mapper=mapper)
     return data_loader
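For intuition about the combination step added above, here is a minimal standalone sketch with made-up numbers; only the element-wise product and min-normalization mirror the diff.

import torch

# Illustration only: hypothetical per-image repeat factors for a 5-image training set.
# Factors from dataset-level weighting (TRAIN_REPEAT_FACTOR) ...
dataset_rf = torch.tensor([1.0, 1.0, 2.0, 2.0, 2.0])
# ... and factors from category frequency (REPEAT_THRESHOLD).
category_rf = torch.tensor([1.0, 3.0, 1.0, 1.5, 1.0])

# Combine as the new code does: element-wise product, then rescale so the
# smallest factor is 1.0 (no image is repeated less than once per epoch).
combined = torch.mul(category_rf, dataset_rf)
combined = combined / torch.min(combined)
print(combined)  # tensor([1., 3., 2., 3., 2.])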
