Skip to content

Commit 77b4891

Browse files
[Fix] Fix a bug about multi-class in VideoDataset (open-mmlab#723)
* Fix 722 * add unittest and update changelog
1 parent afce5ca commit 77b4891

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

docs/changelog.md

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
**Bug and Typo Fixes**
1818

19+
- Fix a bug about multi-class in VideoDataset ([#723](https://github.com/open-mmlab/mmaction2/pull/678))
20+
1921
**ModelZoo**
2022

2123
- Add LFB for AVA2.1 ([#553](https://github.com/open-mmlab/mmaction2/pull/553))

mmaction/datasets/video_dataset.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import os.path as osp
22

3-
import torch
4-
53
from .base import BaseDataset
64
from .registry import DATASETS
75

@@ -53,15 +51,10 @@ def load_annotations(self):
5351
assert self.num_classes is not None
5452
filename, label = line_split[0], line_split[1:]
5553
label = list(map(int, label))
56-
onehot = torch.zeros(self.num_classes)
57-
onehot[label] = 1.0
5854
else:
5955
filename, label = line_split
6056
label = int(label)
6157
if self.data_prefix is not None:
6258
filename = osp.join(self.data_prefix, filename)
63-
video_infos.append(
64-
dict(
65-
filename=filename,
66-
label=onehot if self.multi_class else label))
59+
video_infos.append(dict(filename=filename, label=label))
6760
return video_infos
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
test.mp4 0 3
2+
test.mp4 0 2 4

tests/test_data/test_datasets/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def setup_class(cls):
4141
'rawvideo_test_anno.txt')
4242
cls.video_ann_file = osp.join(cls.ann_file_prefix,
4343
'video_test_list.txt')
44+
cls.video_ann_file_multi_label = osp.join(
45+
cls.ann_file_prefix, 'video_test_list_multi_label.txt')
4446

4547
# pipeline configuration
4648
cls.action_pipeline = []

tests/test_data/test_datasets/test_video_dataset.py

+18
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ def test_video_dataset(self):
2828
assert video_infos == [dict(filename=video_filename, label=0)] * 2
2929
assert video_dataset.start_index == 0
3030

31+
def test_video_dataset_multi_label(self):
32+
video_dataset = VideoDataset(
33+
self.video_ann_file_multi_label,
34+
self.video_pipeline,
35+
data_prefix=self.data_prefix,
36+
multi_class=True,
37+
num_classes=100)
38+
video_infos = video_dataset.video_infos
39+
video_filename = osp.join(self.data_prefix, 'test.mp4')
40+
label0 = [0, 3]
41+
label1 = [0, 2, 4]
42+
labels = [label0, label1]
43+
for info, label in zip(video_infos, labels):
44+
print(info, video_filename)
45+
assert info['filename'] == video_filename
46+
assert set(info['label']) == set(label)
47+
assert video_dataset.start_index == 0
48+
3149
def test_video_pipeline(self):
3250
target_keys = ['filename', 'label', 'start_index', 'modality']
3351

0 commit comments

Comments
 (0)