diff --git a/configs/top_down/resnet/README.md b/configs/top_down/resnet/README.md index 58d8b7ac33..0a5089a179 100644 --- a/configs/top_down/resnet/README.md +++ b/configs/top_down/resnet/README.md @@ -29,3 +29,10 @@ | Arch | Input Size | Skeleton Acc | Contour Acc | Mean Acc | ckpt | log | | :--- | :--------: | :------: | :------: |:------: |:------: |:------: | | [pose_resnet_50](/configs/top_down/resnet/mpii_trb/res50_mpii_trb_256x256.py) | 256x256 | 0.884 | 0.855 | 0.865 | [ckpt](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmpose/top_down/resnet/res50_mpii_trb_256x256-f0305d2e_20200727.pth) | [log](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmpose/top_down/resnet/res50_mpii_trb_256x256_20200727.log.json) | + + +### Results on OneHand10K val set. + +| Arch | Input Size | PCK@0.2 | ckpt | log | +| :--- | :--------: | :------: | :------: |:------: | +| [pose_resnet_50](/configs/top_down/resnet/onehand10k/res50_onehand10k_256x256.py) | 256x256 | 0.985 | [ckpt](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmpose/top_down/resnet/res50_onehand10k_256x256-e67998f6_20200813.pth) | [log](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmpose/top_down/resnet/res50_onehand10k_256x256_20200813.log.json) | diff --git a/configs/top_down/resnet/onehand10k/res50_onehand10k_256x256.py b/configs/top_down/resnet/onehand10k/res50_onehand10k_256x256.py new file mode 100644 index 0000000000..583a2bda73 --- /dev/null +++ b/configs/top_down/resnet/onehand10k/res50_onehand10k_256x256.py @@ -0,0 +1,131 @@ +log_level = 'INFO' +load_from = None +resume_from = None +dist_params = dict(backend='nccl') +workflow = [('train', 1)] +checkpoint_config = dict(interval=10) +evaluation = dict(interval=1, metric='PCK') + +optimizer = dict( + type='Adam', + lr=5e-4, +) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +log_config = dict( + interval=20, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +channel_cfg = dict( + num_output_channels=21, + dataset_joints=21, + dataset_channel=[ + [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20 + ], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained='models/pytorch/imagenet/resnet50-19c8e357.pth', + backbone=dict(type='ResNet', depth=50), + keypoint_head=dict( + type='TopDownSimpleHead', + in_channels=2048, + out_channels=channel_cfg['num_output_channels'], + ), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process=True, + shift_heatmap=True, + unbiased_decoding=False, + modulate_kernel=11), + loss_pose=dict(type='JointsMSELoss', use_target_weight=True)) + +data_cfg = dict( + image_size=[256, 256], + heatmap_size=[64, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel']) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=20, scale_factor=0.3), + dict(type='TopDownAffine'), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict(type='TopDownGenerateTarget', sigma=2), + dict( + type='Collect', + 
+        keys=['img', 'target', 'target_weight'],
+        meta_keys=[
+            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+            'rotation', 'flip_pairs'
+        ]),
+]
+
+valid_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='TopDownAffine'),
+    dict(type='ToTensor'),
+    dict(
+        type='NormalizeTensor',
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225]),
+    dict(
+        type='Collect',
+        keys=[
+            'img',
+        ],
+        meta_keys=['image_file', 'center', 'scale', 'rotation', 'flip_pairs']),
+]
+
+test_pipeline = valid_pipeline
+
+data_root = 'data/onehand10k'
+data = dict(
+    samples_per_gpu=32,
+    workers_per_gpu=2,
+    train=dict(
+        type='TopDownOneHand10KDataset',
+        ann_file=f'{data_root}/annotations/onehand10k_train.json',
+        img_prefix=f'{data_root}/',
+        data_cfg=data_cfg,
+        pipeline=train_pipeline),
+    val=dict(
+        type='TopDownOneHand10KDataset',
+        ann_file=f'{data_root}/annotations/onehand10k_test.json',
+        img_prefix=f'{data_root}/',
+        data_cfg=data_cfg,
+        pipeline=valid_pipeline),
+    test=dict(
+        type='TopDownOneHand10KDataset',
+        ann_file=f'{data_root}/annotations/onehand10k_test.json',
+        img_prefix=f'{data_root}/',
+        data_cfg=data_cfg,
+        pipeline=valid_pipeline),
+)
diff --git a/docs/getting_started.md b/docs/getting_started.md
index 835a13f34a..5f8c9d6083 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -86,6 +86,34 @@ mmpose
 ```
 
+**For OneHand10K data**, please download the images from the [OneHand10K Dataset](https://www.yangangwang.com/papers/WANG-MCC-2018-10.html) homepage
+and the annotation files from [onehand10k_annotations](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmpose/datasets/onehand10k_annotations.tar).
+Extract them under {MMPose}/data so that the directory looks like this:
+
+```
+mmpose
+├── mmpose
+├── docs
+├── tests
+├── tools
+├── configs
+`── data
+    │── onehand10k
+        |── annotations
+        |   |── onehand10k_train.json
+        |   |── onehand10k_test.json
+        |── Train
+        |   |── source
+        |       |── 0.jpg
+        |       |── 1.jpg
+        |       ...
+        `── Test
+            |── source
+                |── 0.jpg
+                |── 1.jpg
+
+```
+
 For using custom datasets, please refer to [Tutorial 2: Adding New Dataset](tutorials/new_dataset.md)
 
 ## Prepare Pretrained Models
diff --git a/mmpose/datasets/__init__.py b/mmpose/datasets/__init__.py
index 10782d27de..b1eb343456 100644
--- a/mmpose/datasets/__init__.py
+++ b/mmpose/datasets/__init__.py
@@ -1,12 +1,12 @@
 from .builder import build_dataloader, build_dataset
 from .datasets import (BottomUpCocoDataset, TopDownCocoDataset,
-                       TopDownMpiiTrbDataset)
+                       TopDownMpiiTrbDataset, TopDownOneHand10KDataset)
 from .pipelines import Compose
 from .registry import DATASETS, PIPELINES
 from .samplers import DistributedSampler
 
 __all__ = [
     'TopDownCocoDataset', 'BottomUpCocoDataset', 'TopDownMpiiTrbDataset',
-    'build_dataloader', 'build_dataset', 'Compose', 'DistributedSampler',
-    'DATASETS', 'PIPELINES'
+    'TopDownOneHand10KDataset', 'build_dataloader', 'build_dataset', 'Compose',
+    'DistributedSampler', 'DATASETS', 'PIPELINES'
 ]
diff --git a/mmpose/datasets/datasets/__init__.py b/mmpose/datasets/datasets/__init__.py
index 143a547349..0893aa3455 100644
--- a/mmpose/datasets/datasets/__init__.py
+++ b/mmpose/datasets/datasets/__init__.py
@@ -1,6 +1,8 @@
 from .bottom_up import BottomUpCocoDataset
-from .top_down import TopDownCocoDataset, TopDownMpiiTrbDataset
+from .top_down import (TopDownCocoDataset, TopDownMpiiTrbDataset,
+                       TopDownOneHand10KDataset)
 
 __all__ = [
-    'TopDownCocoDataset', 'BottomUpCocoDataset', 'TopDownMpiiTrbDataset'
+    'TopDownCocoDataset', 'BottomUpCocoDataset', 'TopDownMpiiTrbDataset',
+    'TopDownOneHand10KDataset'
 ]
diff --git a/mmpose/datasets/datasets/top_down/__init__.py b/mmpose/datasets/datasets/top_down/__init__.py
index 75ac18b1d3..c85ae6d0d5 100644
--- a/mmpose/datasets/datasets/top_down/__init__.py
+++ b/mmpose/datasets/datasets/top_down/__init__.py
@@ -1,4 +1,7 @@
 from .topdown_coco_dataset import TopDownCocoDataset
 from .topdown_mpii_trb_dataset import TopDownMpiiTrbDataset
+from .topdown_onehand10k_dataset import TopDownOneHand10KDataset
 
-__all__ = ['TopDownCocoDataset', 'TopDownMpiiTrbDataset']
+__all__ = [
+    'TopDownCocoDataset', 'TopDownMpiiTrbDataset', 'TopDownOneHand10KDataset'
+]
diff --git a/mmpose/datasets/datasets/top_down/topdown_onehand10k_dataset.py b/mmpose/datasets/datasets/top_down/topdown_onehand10k_dataset.py
new file mode 100644
index 0000000000..d2364e8523
--- /dev/null
+++ b/mmpose/datasets/datasets/top_down/topdown_onehand10k_dataset.py
@@ -0,0 +1,233 @@
+import copy as cp
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import json_tricks as json
+import numpy as np
+
+from mmpose.datasets.builder import DATASETS
+from .topdown_base_dataset import TopDownBaseDataset
+
+
+@DATASETS.register_module()
+class TopDownOneHand10KDataset(TopDownBaseDataset):
+    """OneHand10K dataset for top-down hand pose estimation.
+
+    The dataset loads raw features and applies the specified transforms
+    to return a dict containing the image tensors and other information.
+
+    OneHand10K keypoint indexes::
+
+        0: 'wrist',
+        1: 'thumb1',
+        2: 'thumb2',
+        3: 'thumb3',
+        4: 'thumb4',
+        5: 'forefinger1',
+        6: 'forefinger2',
+        7: 'forefinger3',
+        8: 'forefinger4',
+        9: 'middle_finger1',
+        10: 'middle_finger2',
+        11: 'middle_finger3',
+        12: 'middle_finger4',
+        13: 'ring_finger1',
+        14: 'ring_finger2',
+        15: 'ring_finger3',
+        16: 'ring_finger4',
+        17: 'pinky_finger1',
+        18: 'pinky_finger2',
+        19: 'pinky_finger3',
+        20: 'pinky_finger4'
+
+    Args:
+        ann_file (str): Path to the annotation file.
+        img_prefix (str): Path to the directory where images are held.
+        data_cfg (dict): Dataset config.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        test_mode (bool): Store True when building the test or
+            validation dataset. Default: False.
+    """
+
+    def __init__(self,
+                 ann_file,
+                 img_prefix,
+                 data_cfg,
+                 pipeline,
+                 test_mode=False):
+
+        super().__init__(
+            ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)
+
+        self.ann_info['flip_pairs'] = []
+
+        self.ann_info['use_different_joints_weight'] = False
+        assert self.ann_info['num_joints'] == 21
+        self.ann_info['joints_weight'] = \
+            np.ones((self.ann_info['num_joints'], 1), dtype=np.float32)
+
+        self.db = self._get_db(ann_file)
+        self.image_set = {x['image_file'] for x in self.db}
+        self.num_images = len(self.image_set)
+
+        print(f'=> num_images: {self.num_images}')
+        print(f'=> load {len(self.db)} samples')
+
+    def _get_db(self, ann_file):
+        """Load dataset."""
+        with open(ann_file, 'r') as f:
+            data = json.load(f)
+        tmpl = dict(
+            image_file=None,
+            center=None,
+            scale=None,
+            rotation=0,
+            joints_3d=None,
+            joints_3d_visible=None,
+            bbox=None,
+            dataset='OneHand10K')
+
+        imid2info = {x['id']: x for x in data['images']}
+
+        num_joints = self.ann_info['num_joints']
+        gt_db = []
+
+        for anno in data['annotations']:
+            newitem = cp.deepcopy(tmpl)
+            image_id = anno['image_id']
+            newitem['image_file'] = os.path.join(
+                self.img_prefix, imid2info[image_id]['file_name'])
+
+            # skip instances without any labeled keypoint
+            if max(anno['keypoints']) == 0:
+                continue
+
+            joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
+            joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
+
+            for ipt in range(num_joints):
+                joints_3d[ipt, 0] = anno['keypoints'][ipt * 3 + 0]
+                joints_3d[ipt, 1] = anno['keypoints'][ipt * 3 + 1]
+                joints_3d[ipt, 2] = 0
+                t_vis = min(anno['keypoints'][ipt * 3 + 2], 1)
+                joints_3d_visible[ipt, :] = (t_vis, t_vis, 0)
+
+            center, scale = self._xywh2cs(*anno['bbox'][:4])
+            newitem['center'] = center
+            newitem['scale'] = scale
+            newitem['joints_3d'] = joints_3d
+            newitem['joints_3d_visible'] = joints_3d_visible
+            newitem['bbox'] = anno['bbox'][:4]
+            gt_db.append(newitem)
+
+        return gt_db
+
+    def _xywh2cs(self, x, y, w, h):
+        """This encodes a bbox (x, y, w, h) into (center, scale).
+
+        Args:
+            x, y, w, h (float): Left, top, width and height of the bbox.
+
+        Returns:
+            center (np.ndarray[float32](2,)): center of the bbox (x, y).
+            scale (np.ndarray[float32](2,)): scale of the bbox w & h.
+        """
+        aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[
+            'image_size'][1]
+        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
+
+        # randomly shift the center during training for augmentation
+        if (not self.test_mode) and np.random.rand() < 0.3:
+            center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
+
+        # pad the bbox to match the aspect ratio of the model input
+        if w > aspect_ratio * h:
+            h = w * 1.0 / aspect_ratio
+        elif w < aspect_ratio * h:
+            w = h * aspect_ratio
+
+        # pixel std is 200.0
+        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
+
+        scale = scale * 1.25
+
+        return center, scale
+
+    def _evaluate_kernel(self, pred, joints_3d, joints_3d_visible, bbox):
+        """Evaluate one example.
+
+        A joint is counted as a hit if
+        ||pred[i] - joints_3d[i]|| < 0.2 * max(w, h)
+        """
+        num_joints = self.ann_info['num_joints']
+        bbox = np.array(bbox)
+        threshold = np.max(bbox[2:]) * 0.2
+        hit = np.zeros(num_joints, dtype=np.float32)
+        exist = np.zeros(num_joints, dtype=np.float32)
+
+        for i in range(num_joints):
+            pred_pt = pred[i]
+            gt_pt = joints_3d[i]
+            vis = joints_3d_visible[i][0]
+            if vis:
+                exist[i] = 1
+            else:
+                # invisible joints are excluded from the evaluation
+                continue
+            distance = np.linalg.norm(pred_pt[:2] - gt_pt[:2])
+            if distance < threshold:
+                hit[i] = 1
+        return hit, exist
+
+    def evaluate(self, outputs, res_folder, metrics='PCK', **kwargs):
+        """Evaluate OneHand10K keypoint results."""
+        res_file = os.path.join(res_folder, 'result_keypoints.json')
+
+        kpts = []
+
+        for preds, boxes, image_path in outputs:
+            str_image_path = ''.join(image_path)
+            image_id = int(osp.basename(osp.splitext(str_image_path)[0]))
+
+            kpts.append({
+                'keypoints': preds[0].tolist(),
+                'center': boxes[0][0:2].tolist(),
+                'scale': boxes[0][2:4].tolist(),
+                'area': float(boxes[0][4]),
+                'score': float(boxes[0][5]),
+                'image_id': image_id,
+            })
+
+        self._write_keypoint_results(kpts, res_file)
+        info_str = self._report_metric(res_file)
+        name_value = OrderedDict(info_str)
+
+        return name_value
+
+    def _write_keypoint_results(self, keypoints, res_file):
+        """Write results into a json file."""
+
+        with open(res_file, 'w') as f:
+            json.dump(keypoints, f, sort_keys=True, indent=4)
+
+    def _report_metric(self, res_file):
+        """Keypoint evaluation.
+
+        Report PCK@0.2 over all visible joints.
+        """
+        num_joints = self.ann_info['num_joints']
+        hit = np.zeros(num_joints, dtype=np.float32)
+        exist = np.zeros(num_joints, dtype=np.float32)
+
+        with open(res_file, 'r') as fin:
+            preds = json.load(fin)
+
+        assert len(preds) == len(self.db)
+        for pred, item in zip(preds, self.db):
+            h, e = self._evaluate_kernel(pred['keypoints'], item['joints_3d'],
+                                         item['joints_3d_visible'],
+                                         item['bbox'])
+            hit += h
+            exist += e
+        pck = np.sum(hit) / np.sum(exist)
+
+        info_str = []
+        info_str.append(('PCK', pck.item()))
+        return info_str
diff --git a/tests/data/OneHand10K/160.jpg b/tests/data/OneHand10K/160.jpg
new file mode 100755
index 0000000000..891919b576
Binary files /dev/null and b/tests/data/OneHand10K/160.jpg differ
diff --git a/tests/data/OneHand10K/2251.jpg b/tests/data/OneHand10K/2251.jpg
new file mode 100755
index 0000000000..df8d212c23
Binary files /dev/null and b/tests/data/OneHand10K/2251.jpg differ
diff --git a/tests/data/OneHand10K/5.jpg b/tests/data/OneHand10K/5.jpg
new file mode 100755
index 0000000000..7059d790fd
Binary files /dev/null and b/tests/data/OneHand10K/5.jpg differ
diff --git a/tests/data/OneHand10K/75.jpg b/tests/data/OneHand10K/75.jpg
new file mode 100755
index 0000000000..c385278a94
Binary files /dev/null and b/tests/data/OneHand10K/75.jpg differ
diff --git a/tests/data/OneHand10K/test_onehand10k.json b/tests/data/OneHand10K/test_onehand10k.json
new file mode 100755
index 0000000000..aa9c6cee68
--- /dev/null
+++ b/tests/data/OneHand10K/test_onehand10k.json
@@ -0,0 +1,541 @@
+{
+    "info": {
+        "description": "OneHand10K",
+        "version": "1.0",
+        "year": "2020",
+        "date_created": "2020/08/03"
+    },
+    "licenses": "",
+    "images": [
+        {
+            "file_name": "Train/source/5.jpg",
+            "height": 960,
+            "width": 1280,
+            "id": 5
+        },
+        {
+            "file_name": "Train/source/75.jpg",
+            "height": 330,
+            "width": 500,
+            "id": 75
+        },
+        {
+            "file_name": "Train/source/160.jpg",
+            "height": 157,
+            "width": 235,
+            "id": 160
+        },
+        {
+            "file_name": "Train/source/2251.jpg",
+            "height": 5312,
"width": 3984, + "id": 2251 + } + ], + "annotations": [ + { + "bbox": [ + 97, + 150, + 1057, + 622 + ], + "keypoints": [ + 149, + 432, + 1, + 367, + 241, + 1, + 598, + 207, + 1, + 695, + 196, + 1, + 820, + 199, + 1, + 760, + 414, + 1, + 905, + 444, + 1, + 1019, + 459, + 1, + 1102, + 489, + 1, + 737, + 522, + 1, + 887, + 529, + 1, + 1010, + 564, + 1, + 1153, + 592, + 1, + 713, + 616, + 1, + 829, + 642, + 1, + 917, + 664, + 1, + 1067, + 693, + 1, + 571, + 685, + 1, + 706, + 741, + 1, + 778, + 751, + 1, + 850, + 756, + 1 + ], + "category_id": 1, + "id": 5, + "image_id": 5, + "segmentation": [ + [ + 97, + 150, + 97, + 460.5, + 97, + 771, + 625.0, + 771, + 1153, + 771, + 1153, + 460.5, + 1153, + 150, + 625.0, + 150 + ] + ], + "iscrowd": 0, + "area": 657454 + }, + { + "bbox": [ + 82, + 106, + 300, + 176 + ], + "keypoints": [ + 134, + 219, + 1, + 177, + 165, + 1, + 215, + 135, + 1, + 247, + 124, + 1, + 276, + 122, + 1, + 272, + 152, + 1, + 306, + 135, + 1, + 333, + 119, + 1, + 362, + 110, + 1, + 275, + 171, + 1, + 318, + 152, + 1, + 353, + 137, + 1, + 380, + 125, + 1, + 281, + 200, + 1, + 315, + 190, + 1, + 345, + 184, + 1, + 377, + 169, + 1, + 277, + 229, + 1, + 310, + 229, + 1, + 329, + 222, + 1, + 349, + 211, + 1 + ], + "category_id": 1, + "id": 75, + "image_id": 75, + "segmentation": [ + [ + 82, + 106, + 82, + 193.5, + 82, + 281, + 231.5, + 281, + 381, + 281, + 381, + 193.5, + 381, + 106, + 231.5, + 106 + ] + ], + "iscrowd": 0, + "area": 52800 + }, + { + "bbox": [ + 29, + 21, + 84, + 136 + ], + "keypoints": [ + 57, + 151, + 1, + 79, + 139, + 1, + 80, + 110, + 1, + 82, + 71, + 1, + 84, + 52, + 1, + 57, + 73, + 1, + 57, + 39, + 1, + 75, + 28, + 1, + 94, + 26, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 112, + 98, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 105, + 100, + 1, + 94, + 116, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 105, + 116, + 1, + 96, + 127, + 1 + ], + "category_id": 1, + "id": 160, + "image_id": 160, + "segmentation": [ + [ + 29, + 21, + 29, + 88.5, + 29, + 156, + 70.5, + 156, + 112, + 156, + 112, + 88.5, + 112, + 21, + 70.5, + 21 + ] + ], + "iscrowd": 0, + "area": 11424 + }, + { + "bbox": [ + 426, + 1582, + 2437, + 1690 + ], + "keypoints": [ + 426, + 2117, + 1, + 1279, + 1949, + 1, + 1675, + 1745, + 1, + 2035, + 1745, + 1, + 2371, + 2033, + 1, + 1963, + 2141, + 1, + 0, + 0, + 0, + 2803, + 2513, + 1, + 2479, + 2765, + 1, + 0, + 0, + 0, + 2323, + 3005, + 1, + 1711, + 2705, + 1, + 1639, + 2429, + 1, + 0, + 0, + 0, + 2010, + 3161, + 1, + 1555, + 2849, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1507, + 3173, + 1, + 1339, + 2837, + 1, + 0, + 0, + 0 + ], + "category_id": 1, + "id": 2251, + "image_id": 2251, + "segmentation": [ + [ + 426, + 1582, + 426, + 2426.5, + 426, + 3271, + 1644.0, + 3271, + 2862, + 3271, + 2862, + 2426.5, + 2862, + 1582, + 1644.0, + 1582 + ] + ], + "iscrowd": 0, + "area": 4118530 + } + ], + "categories": [ + { + "supercategory": "hand", + "id": 1, + "name": "hand", + "keypoints": [ + "wrist", + "thumb1", + "thumb2", + "thumb3", + "thumb4", + "forefinger1", + "forefinger2", + "forefinger3", + "forefinger4", + "middle_finger1", + "middle_finger2", + "middle_finger3", + "middle_finger4", + "ring_finger1", + "ring_finger2", + "ring_finger3", + "ring_finger4", + "pinky_finger1", + "pinky_finger2", + "pinky_finger3", + "pinky_finger4" + ], + "skeleton": [ + [ + 1, + 2 + ], + [ + 2, + 3 + ], + [ + 3, + 4 + ], + [ + 4, + 5 + ], + [ + 1, + 6 + ], + [ + 6, + 7 + ], + [ + 7, + 8 + ], + [ + 8, + 9 + ], + [ + 1, + 10 + ], + [ + 10, + 11 + ], + [ + 11, + 12 + ], + [ + 12, + 13 + ], + [ 
+                1,
+                14
+            ],
+            [
+                14,
+                15
+            ],
+            [
+                15,
+                16
+            ],
+            [
+                16,
+                17
+            ],
+            [
+                1,
+                18
+            ],
+            [
+                18,
+                19
+            ],
+            [
+                19,
+                20
+            ],
+            [
+                20,
+                21
+            ]
+        ]
+        }
+    ]
+}
diff --git a/tests/test_datasets/test_top_down_dataset.py b/tests/test_datasets/test_top_down_dataset.py
index 7841658ce3..5f8b73e81e 100644
--- a/tests/test_datasets/test_top_down_dataset.py
+++ b/tests/test_datasets/test_top_down_dataset.py
@@ -67,3 +67,50 @@ def test_top_down_COCO_dataset():
     image_id = 785
     assert image_id in custom_dataset.image_set_index
     assert len(custom_dataset.image_set_index) == 4
+
+
+def test_top_down_OneHand10K_dataset():
+    dataset = 'TopDownOneHand10KDataset'
+    # test the OneHand10K dataset
+    dataset_class = DATASETS.get(dataset)
+
+    channel_cfg = dict(
+        num_output_channels=21,
+        dataset_joints=21,
+        dataset_channel=[
+            [
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
+                18, 19, 20
+            ],
+        ],
+        inference_channel=[
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+            19, 20
+        ])
+
+    data_cfg = dict(
+        image_size=[256, 256],
+        heatmap_size=[64, 64],
+        num_output_channels=channel_cfg['num_output_channels'],
+        num_joints=channel_cfg['dataset_joints'],
+        dataset_channel=channel_cfg['dataset_channel'],
+        inference_channel=channel_cfg['inference_channel'])
+    # build the dataset in both test mode and train mode
+    data_cfg_copy = copy.deepcopy(data_cfg)
+    _ = dataset_class(
+        ann_file='tests/data/OneHand10K/test_onehand10k.json',
+        img_prefix='tests/data/OneHand10K/',
+        data_cfg=data_cfg_copy,
+        pipeline=[],
+        test_mode=True)
+
+    custom_dataset = dataset_class(
+        ann_file='tests/data/OneHand10K/test_onehand10k.json',
+        img_prefix='tests/data/OneHand10K/',
+        data_cfg=data_cfg_copy,
+        pipeline=[],
+        test_mode=False)
+
+    assert custom_dataset.test_mode is False
+
+    assert custom_dataset.num_images == 4
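
A note on the evaluation protocol: the sketch below is a small, self-contained restatement of the PCK@0.2 rule that `_evaluate_kernel` and `_report_metric` implement above. The helper name `pck_at_02` and the toy arrays are made up for illustration; only the thresholding logic mirrors the dataset code.

```python
# Minimal sketch of PCK@0.2: a visible joint counts as a hit when the
# distance between prediction and ground truth is below 0.2 * max(w, h)
# of the ground-truth bbox; invisible joints are excluded entirely.
import numpy as np


def pck_at_02(pred, gt, visible, bbox):
    """pred, gt: (K, 2) arrays; visible: (K,) 0/1 flags; bbox: (x, y, w, h)."""
    threshold = 0.2 * max(bbox[2], bbox[3])   # 0.2 * max(w, h)
    dist = np.linalg.norm(pred - gt, axis=1)  # per-joint euclidean distance
    hit = (dist < threshold) & (visible > 0)
    exist = visible > 0
    return hit.astype(np.float32), exist.astype(np.float32)


pred = np.array([[10., 10.], [50., 60.], [90., 95.]])
gt = np.array([[12., 11.], [80., 60.], [90., 90.]])
visible = np.array([1, 1, 0])
bbox = (0, 0, 100, 100)  # threshold = 20 pixels

hit, exist = pck_at_02(pred, gt, visible, bbox)
print(hit.sum() / exist.sum())  # 0.5: joint 0 hits, joint 1 misses, joint 2 ignored
```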
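The new dataset is exercised through the usual MMPose entry points. Below is a hedged sketch, not part of the patch, of how the registered class can be built from the new config, assuming the OneHand10K images and annotations have been laid out under `data/onehand10k` as described in `docs/getting_started.md`:

```python
# Sketch only: build the train split of the new dataset via the registry,
# mirroring what tools/train.py does internally. Requires the OneHand10K
# data to exist locally at the paths referenced by the config.
from mmcv import Config

from mmpose.datasets import build_dataset

cfg = Config.fromfile(
    'configs/top_down/resnet/onehand10k/res50_onehand10k_256x256.py')
dataset = build_dataset(cfg.data.train)
print(len(dataset))  # number of annotated hand instances
```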