Skip to content

Commit

Permalink
Enable Class Balancing for Model Train Sampler
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/detectron2#4995

Pull Request resolved: facebookresearch#570

Differential Revision: D46377371

fbshipit-source-id: 8e5e8b859a77eb291bfa30e54ce45126e0d2cd60
  • Loading branch information
dxzhou2023 authored and facebook-github-bot committed Jun 12, 2023
1 parent ee5ae5e commit 99f6b55
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions d2go/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
return name_to_weight


def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
def build_weighted_detection_train_loader(
cfg: CfgNode, mapper=None, enable_category_balance=False
):
dataset_repeat_factors = get_train_datasets_repeat_factors(cfg)
# OrderedDict to guarantee order of values() consistent with repeat factors
dataset_name_to_dicts = OrderedDict(
Expand Down Expand Up @@ -103,7 +105,22 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
)
sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors))
repeat_factors = torch.tensor(repeat_factors)
if enable_category_balance:
"""
Current repeat factor is already balanced by dataset frequency
Use annotations information in dataset_dicts to compute repeat factors to balance category frequencies.
Element wise dot producting these two list of repeat factors gives the final balanced repeat factors
"""
class_balance_repeat_factors = (
RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
)
)
repeat_factors = torch.mul(class_balance_repeat_factors, repeat_factors)
repeat_factors = repeat_factors / torch.min(repeat_factors)

sampler = RepeatFactorTrainingSampler(repeat_factors)

return build_batch_data_loader(
dataset,
Expand Down Expand Up @@ -149,7 +166,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work
@fb_overwritable()
def build_mapped_train_loader(cfg, mapper):
if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
# balancing only datasets frequencies
data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
elif cfg.DATALOADER.SAMPLER_TRAIN == "WeightedCategoryTrainingSampler":
# balancing both datasets and its categories
data_loader = build_weighted_detection_train_loader(
cfg, mapper=mapper, enable_class_balance=True
)
else:
data_loader = build_detection_train_loader(cfg, mapper=mapper)
return data_loader
Expand Down

0 comments on commit 99f6b55

Please sign in to comment.