Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Weighted Sampler for highly imbalanced datasets #8766

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
75e3d76
add sampler
pourmand1376 Jul 28, 2022
0284aaf
add flag
pourmand1376 Jul 28, 2022
bdeeb2c
rank
pourmand1376 Jul 28, 2022
8e47fd7
assert
pourmand1376 Jul 28, 2022
1f2b073
add weighted sampler
pourmand1376 Jul 28, 2022
c78ebaa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2022
08dfe8b
fix bug
pourmand1376 Jul 28, 2022
0d73fb8
remove normalized count
pourmand1376 Jul 29, 2022
fa157e7
add validation check
pourmand1376 Jul 29, 2022
c577a9c
remove comment
pourmand1376 Jul 29, 2022
0ebde56
Merge branch 'master' into add_sampler
glenn-jocher Jul 29, 2022
475aeed
Merge branch 'master' into add_sampler
glenn-jocher Jul 29, 2022
1dc1c51
Merge branch 'master' into add_sampler
pourmand1376 Jul 31, 2022
0201b44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2022
9929f04
Merge branch 'master' into add_sampler
pourmand1376 Aug 1, 2022
2f07c9c
Merge branch 'master' into add_sampler
pourmand1376 Aug 2, 2022
d19b3ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2022
c4dc924
Merge branch 'master' into add_sampler
pourmand1376 Aug 3, 2022
2385e15
Merge branch 'master' into add_sampler
pourmand1376 Aug 10, 2022
9e5fb7c
Merge branch 'master' into add_sampler
pourmand1376 Aug 16, 2022
937e46d
Merge branch 'master' into add_sampler
pourmand1376 Aug 18, 2022
55c6e55
fix bug
pourmand1376 Aug 18, 2022
5384cad
Merge branch 'add_sampler' of https://github.com/pourmand1376/yolov5 …
pourmand1376 Aug 18, 2022
edc36ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2022
e803bec
Merge branch 'master' into add_sampler
pourmand1376 Aug 18, 2022
950c2d6
Merge branch 'master' into add_sampler
pourmand1376 Aug 18, 2022
7ad26de
Merge branch 'master' into add_sampler
pourmand1376 Aug 19, 2022
5aaf544
Merge branch 'master' into add_sampler
pourmand1376 Aug 20, 2022
ef34b5c
Merge branch 'master' into add_sampler
pourmand1376 Aug 21, 2022
7f4e58d
.
pourmand1376 Aug 21, 2022
65ef280
Merge branch 'add_sampler' of https://github.com/pourmand1376/yolov5 …
pourmand1376 Aug 21, 2022
a85e330
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2022
950bcb6
Merge branch 'master' into add_sampler
pourmand1376 Aug 21, 2022
e7fe107
add keyword arguments
pourmand1376 Aug 21, 2022
469636f
Merge branch 'master' into add_sampler
pourmand1376 Aug 21, 2022
97eb137
Merge branch 'master' into add_sampler
pourmand1376 Aug 21, 2022
5d0b07b
Merge branch 'master' into add_sampler
pourmand1376 Aug 25, 2022
2d4a541
Merge branch 'master' into add_sampler
pourmand1376 Aug 30, 2022
547bc31
Merge branch 'master' into add_sampler
pourmand1376 Aug 30, 2022
1878699
Merge branch 'master' into add_sampler
pourmand1376 Aug 31, 2022
b6ace34
Merge branch 'master' into add_sampler
pourmand1376 Nov 9, 2022
9a10a0d
Merge branch 'master' into add_sampler
pourmand1376 Nov 13, 2022
a354b3c
Merge branch 'master' into add_sampler
pourmand1376 Nov 26, 2022
b846905
Merge branch 'master' into add_sampler
pourmand1376 Dec 12, 2022
7bfb43c
Merge branch 'master' into add_sampler
pourmand1376 May 1, 2023
75e823c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2023
82eaa60
Update train.py
pourmand1376 May 1, 2023
f18d3cb
fix comman
pourmand1376 May 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
quad=opt.quad,
prefix=colorstr('train: '),
shuffle=True,
validation=False,
weighted_sampler=opt.weighted_sampler,
seed=opt.seed)
labels = np.concatenate(dataset.labels, 0)
mlc = int(labels[:, 0].max()) # max label class
Expand All @@ -220,7 +222,9 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
rank=-1,
workers=workers * 2,
pad=0.5,
prefix=colorstr('val: '))[0]
prefix=colorstr('val: '),
validation=True,
weighted_sampler=False)[0]

if not resume:
if not opt.noautoanchor:
Expand Down Expand Up @@ -470,6 +474,9 @@ def parse_opt(known=False):
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
parser.add_argument('--weighted_sampler',
action='store_true',
help='Use Weighted Sampler (for highly imbalanced data)')

# Logger arguments
parser.add_argument('--entity', default=None, help='Entity')
Expand Down
51 changes: 51 additions & 0 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from torch.utils.data.sampler import WeightedRandomSampler
from tqdm import tqdm

from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
Expand Down Expand Up @@ -100,6 +101,47 @@ def seed_worker(worker_id):
random.seed(worker_seed)


def create_weighted_sampler(dataset):
    """Build a WeightedRandomSampler that up-samples images containing rare classes.

    Each class gets a weight of ``1 / (number of label rows of that class)``.
    An image's sampling weight is the mean of the weights of the distinct
    classes it contains. Images with no label rows (backgrounds) are treated as
    one extra class with weight ``1 / (number of background images)``.

    Args:
        dataset: Object exposing ``labels``, an iterable with one 2-D array per
            image whose column 0 holds the class id of each label row
            (an array with zero rows marks a background image).

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(dataset.labels)`` samples per
        epoch, using the per-image weights described above.

    Raises:
        ValueError: If ``dataset.labels`` contains no images.
    """
    labels = dataset.labels

    # Class id of every label row in the dataset, flattened into one list.
    class_ids = [cls_id for label in labels if label.shape[0] > 0 for cls_id in label[:, 0].tolist()]
    background_count = sum(1 for label in labels if label.shape[0] == 0)

    # Inverse-frequency weight per class.
    unique_classes, counts = np.unique(np.array(class_ids), return_counts=True)
    weight_dict = dict(zip(unique_classes.tolist(), (1 / counts).tolist()))

    # Only meaningful when backgrounds exist; guard the division so datasets
    # without any background image do not crash with ZeroDivisionError.
    weight_background = 1 / background_count if background_count else 0.0

    final_weights = []
    for label in labels:
        if label.shape[0] == 0:
            final_weights.append(weight_background)
            continue
        # Average over the distinct classes present in this image.
        # NOTE: the lookup key must be this loop's own variable; reusing the
        # variable left over from building the weight table would give every
        # image the weight of the last class.
        image_classes = np.unique(label[:, 0]).tolist()
        final_weights.append(sum(weight_dict[cls_] for cls_ in image_classes) / len(image_classes))

    if not final_weights:
        raise ValueError('Cannot create a weighted sampler for a dataset without images')

    final_weights = np.array(final_weights, dtype=np.float64)
    # num_samples can be set to anything; it only changes the iteration count of every epoch
    return WeightedRandomSampler(weights=torch.from_numpy(final_weights), num_samples=len(final_weights))


def create_dataloader(path,
imgsz,
batch_size,
Expand All @@ -116,6 +158,8 @@ def create_dataloader(path,
quad=False,
prefix='',
shuffle=False,
validation=False,
weighted_sampler=False,
seed=0):
if rect and shuffle:
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
Expand All @@ -138,7 +182,14 @@ def create_dataloader(path,
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers

sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)

if not validation and weighted_sampler:
# weighted sampler should not be called on validation as this will report wrong results
assert rank == -1, 'Currently multi-GPU Support is not enabled when using weighted sampler'
sampler = create_weighted_sampler(dataset)

loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + seed + RANK)
Expand Down
4 changes: 3 additions & 1 deletion val.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def run(
pad=pad,
rect=rect,
workers=workers,
prefix=colorstr(f'{task}: '))[0]
prefix=colorstr(f'{task}: '),
validation=True,
weighted_sampler=False)[0]

seen = 0
confusion_matrix = ConfusionMatrix(nc=nc)
Expand Down