Skip to content

Commit

Permalink
Enable Class Balancing for Model Train Sampler
Browse files Browse the repository at this point in the history
Summary:
This diff enables both category and datasets weight balancing at the same time by declaring "WeightedCategoryTrainingSampler" under "SAMPLER_TRAIN" in config file.

X-link: facebookresearch/detectron2#4995

Pull Request resolved: facebookresearch#570

Differential Revision: D46377371

fbshipit-source-id: 29f135e84f596e3e694a107bc676c2025e6e164f
  • Loading branch information
dxzhou2023 authored and facebook-github-bot committed Jun 15, 2023
1 parent 0389f4e commit 09e3827
Showing 1 changed file with 41 additions and 6 deletions.
47 changes: 41 additions & 6 deletions d2go/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
return name_to_weight


def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
def build_weighted_detection_train_loader(
cfg: CfgNode, mapper=None, enable_category_balance=False
):
dataset_repeat_factors = get_train_datasets_repeat_factors(cfg)
# OrderedDict to guarantee order of values() consistent with repeat factors
dataset_name_to_dicts = OrderedDict(
Expand Down Expand Up @@ -98,12 +100,39 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
mapper = DatasetMapper(cfg, True)
dataset = MapDataset(dataset, mapper)

logger.info(
"Using WeightedTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
repeat_factors = torch.tensor(repeat_factors)
if enable_category_balance:
"""
1. Calculate repeat factors using category frequency for each dataset and then merge them.
2. Element wise dot producting the dataset frequency repeat factors with
the category frequency repeat factors gives the final repeat factors.
"""
category_repeat_factors = [
RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
)
for dataset_dict in dataset_name_to_dicts.values()
]
# flatten the category repeat factors from all datasets
category_repeat_factors = list(
itertools.chain.from_iterable(category_repeat_factors)
)
category_repeat_factors = torch.tensor(category_repeat_factors)
repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
repeat_factors = repeat_factors / torch.min(repeat_factors)
logger.info(
"Using WeightedCategoryTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
)
else:
logger.info(
"Using WeightedTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
)
)
sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors))

sampler = RepeatFactorTrainingSampler(repeat_factors)

return build_batch_data_loader(
dataset,
Expand Down Expand Up @@ -149,7 +178,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work
@fb_overwritable()
def build_mapped_train_loader(cfg, mapper):
if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
# balancing only datasets frequencies
data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
elif cfg.DATALOADER.SAMPLER_TRAIN == "WeightedCategoryTrainingSampler":
# balancing both datasets and its categories
data_loader = build_weighted_detection_train_loader(
cfg, mapper=mapper, enable_category_balance=True
)
else:
data_loader = build_detection_train_loader(cfg, mapper=mapper)
return data_loader
Expand Down

0 comments on commit 09e3827

Please sign in to comment.