Skip to content

Commit

Permalink
修改coco_stuff和pascal_context标注转换问题 (#1257)
Browse files Browse the repository at this point in the history
  • Loading branch information
haoyuying authored Aug 19, 2021
1 parent 37d7745 commit 0e97663
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 19 deletions.
25 changes: 21 additions & 4 deletions docs/data_prepare.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ export PYTHONPATH=`pwd`
其涵盖了150个语义类别,包括训练集20210张,验证集2000张。

## 关于Coco Stuff数据集
Coco Stuff是基于Coco数据集的像素级别语义分割数据集。它主要覆盖172个类别,包含80个'thing',91个'stuff'和1个'unlabeled',
其中训练集118k, 验证集5k.
Coco Stuff是基于Coco数据集的像素级别语义分割数据集。它主要覆盖172个类别,包含80个'thing',91个'stuff'和1个'unlabeled',我们忽略'unlabeled'类别,并将其index设为255,不记录损失。因此提供的训练版本为171个类别。其中,训练集118k, 验证集5k.

在使用Coco Stuff数据集前, 请自行前往[COCO-Stuff主页](https://github.com/nightrome/cocostuff)下载数据集,或者下载[coco2017训练集原图](http://images.cocodataset.org/zips/train2017.zip), [coco2017验证集原图](http://images.cocodataset.org/zips/val2017.zip)[标注图](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
我们建议您将数据集存放于`PaddleSeg/data`中,以便与我们配置文件完全兼容。数据集下载后请组织成如下结构:
Expand All @@ -65,10 +64,18 @@ Coco Stuff是基于Coco数据集的像素级别语义分割数据集。它主要
| |--train2017
| |--val2017

运行下列命令进行标签转换:

```shell
python tools/convert_cocostuff.py --annotation_path /PATH/TO/ANNOTATIONS --save_path /PATH/TO/CONVERT_ANNOTATIONS
```
其中`annotation_path`应根据下载cocostuff/annotations文件夹的实际路径填写。 `save_path`决定转换后标签的存放位置。


其中,标注图像的标签从0,1依次取值,不可间隔。若有需要忽略的像素,则按255进行标注。

## 关于Pascal Context数据集
Pascal Context是基于PASCAL VOC 2010数据集额外标注的像素级别的语义分割数据集。我们提供的转换脚本支持59个类别,其中训练集4996, 验证集5104张.
Pascal Context是基于PASCAL VOC 2010数据集额外标注的像素级别的语义分割数据集。我们提供的转换脚本支持60个类别,index为0是背景类别。该数据集中中训练集4996, 验证集5104张.


在使用Pascal Context数据集前, 请先下载[VOC2010](http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar),随后自行前往[Pascal-Context主页](https://www.cs.stanford.edu/~roozbeh/pascal-context/)下载数据集及[标注](https://codalabuser.blob.core.windows.net/public/trainval_merged.json)
Expand All @@ -87,8 +94,18 @@ Pascal Context是基于PASCAL VOC 2010数据集额外标注的像素级别的语
|--SegmentationObject
|
|--trainval_merged.json


运行下列命令进行标签转换:

```shell
python tools/convert_voc2010.py --voc_path /PATH/TO/VOC ----annotation_path /PATH/TO/JSON
```
其中`voc_path`应根据下载VOC2010文件夹的实际路径填写。 `annotation_path`决定下载trainval_merged.json的存放位置。



其中,标注图像的标签从1,2依次取值,不可间隔。若有需要忽略的像素,则按255(默认的忽略值)进行标注。在使用Pascal Context数据集时,需要安装[Detail](https://github.com/zhanghang1989/detail-api).
其中,标注图像的标签从0,1,2依次取值,不可间隔。若有需要忽略的像素,则按255(默认的忽略值)进行标注。在使用Pascal Context数据集时,需要安装[Detail](https://github.com/zhanghang1989/detail-api).

## 自定义数据集

Expand Down
4 changes: 3 additions & 1 deletion paddleseg/datasets/cocostuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,19 @@ class CocoStuff(Dataset):
transforms (list): Transforms for image.
dataset_root (str): Cityscapes dataset directory.
mode (str): Which part of dataset to use. it is one of ('train', 'val'). Default: 'train'.
edge (bool, optional): Whether to compute edge while training. Default: False
"""
NUM_CLASSES = 171

def __init__(self, transforms, dataset_root, mode='train'):
def __init__(self, transforms, dataset_root, mode='train', edge=False):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
self.file_list = list()
mode = mode.lower()
self.mode = mode
self.num_classes = self.NUM_CLASSES
self.ignore_index = 255
self.edge = edge

if mode not in ['train', 'val']:
raise ValueError(
Expand Down
6 changes: 4 additions & 2 deletions paddleseg/datasets/pascal_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@ class PascalContext(Dataset):
dataset_root (str): The dataset directory. Default: None
mode (str): Which part of dataset to use. it is one of ('train', 'trainval', 'context', 'val').
If you want to set mode to 'context', please make sure the dataset have been augmented. Default: 'train'.
edge (bool, optional): Whether to compute edge while training. Default: False
"""
NUM_CLASSES = 59
NUM_CLASSES = 60

def __init__(self, transforms=None, dataset_root=None, mode='train'):
def __init__(self, transforms=None, dataset_root=None, mode='train', edge=False):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
mode = mode.lower()
self.mode = mode
self.file_list = list()
self.num_classes = self.NUM_CLASSES
self.ignore_index = 255
self.edge = edge

if mode not in ['train', 'trainval', 'val']:
raise ValueError(
Expand Down
102 changes: 102 additions & 0 deletions tools/convert_cocostuff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: convert_cocostuff.py
This file is based on https://github.com/nightrome/cocostuff to generate PASCAL-Context Dataset.
Before running, you should download the COCOSTUFF from https://github.com/nightrome/cocostuff. Then, make the folder
structure as follow:
cocostuff
|
|--images
| |--train2017
| |--val2017
|
|--annotations
| |--train2017
| |--val2017
"""

import os
import argparse

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm


def parse_args():
parser = argparse.ArgumentParser(
description='Generate COCOStuff dataset')
parser.add_argument(
'--annotation_path', default='annotations', help='COCOStuff anotation path', type=str)
parser.add_argument(
'--save_path', default='convert_annotations', help='COCOStuff anotation path', type=str)

return parser.parse_args()


class COCOStuffGenerator(object):
def __init__(self, annotation_path, save_path):

super(COCOStuffGenerator, self).__init__()

self.mapping = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 66, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181]
self.annotation_path = annotation_path
self.save_path = save_path

def encode_label(self, labelmap):
ret = np.ones_like(labelmap) * 255
for idx, label in enumerate(self.mapping):

ret[labelmap == label] = idx
return ret.astype(np.uint8)

def generate_label(self):
train_path = os.path.join(self.annotation_path, 'train2017')
val_path = os.path.join(self.annotation_path, 'val2017')
save_train_path = os.path.join(self.save_path, 'train2017')
save_val_path = os.path.join(self.save_path, 'val2017')

if not os.path.exists(save_train_path):
os.makedirs(save_train_path)
if not os.path.exists(save_val_path):
os.makedirs(save_val_path)

for label_id in tqdm(os.listdir(train_path), desc='trainset'):
label = np.array(
Image.open(os.path.join(train_path, label_id)).convert('P')
)
label = self.encode_label(label)
label = Image.fromarray(label)
label.save(os.path.join(save_train_path, label_id))

for label_id in tqdm(os.listdir(val_path), desc='valset'):
label = np.array(
Image.open(os.path.join(val_path, label_id)).convert('P')
)
label = self.encode_label(label)
label = Image.fromarray(label)
label.save(os.path.join(save_val_path, label_id))

def main():
args = parse_args()
generator = COCOStuffGenerator(
annotation_path=args.annotation_path, save_path=args.save_path)
generator.generate_label()


if __name__ == '__main__':
main()
#/mnt/haoyuying/data/cocostuff/convert_annotations/val2017/000000086336.png
15 changes: 3 additions & 12 deletions tools/convert_voc2010.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"""
File: convert_voc2010.py
This file is based on https://www.cs.stanford.edu/~roozbeh/pascal-context/ to generate PASCAL-Context Dataset.
Before running, you should download the PASCAL VOC2010 from http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar, PASCAL-Context Dataset from https://www.cs.stanford.edu/~roozbeh/pascal-context/ and annotation file from https://codalabuser.blob.core.windows.net/public/trainval_merged.json. Then, make the folder
structure as follow:
Before running, you should download the PASCAL VOC2010 from http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar, PASCAL-Context label id from https://www.cs.stanford.edu/~roozbeh/pascal-context/ and annotation file from https://codalabuser.blob.core.windows.net/public/trainval_merged.json. In segmentation map annotation for PascalContext, 0 stands for background, which is included in 60 categories. Then, make the folder structure as follow:
VOC2010
|
|--Annotations
Expand All @@ -38,9 +38,6 @@
import numpy as np
from detail import Detail
from PIL import Image
from paddleseg.utils.download import _download_file

JSON_URL = 'https://codalabuser.blob.core.windows.net/public/trainval_merged.json'


def parse_args():
Expand All @@ -66,12 +63,6 @@ def __init__(self, voc_path, annotation_path):
self.annFile = os.path.join(self.annotation_path,
'trainval_merged.json')

if not os.path.exists(self.annFile):
_download_file(
url=JSON_URL,
savepath=self.annotation_path,
print_progress=True)

self._mapping = np.sort(
np.array([
0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25,
Expand All @@ -80,7 +71,7 @@ def __init__(self, voc_path, annotation_path):
34, 207, 80, 355, 85, 347, 220, 349, 360, 98, 187, 104, 105,
366, 189, 368, 113, 115
]))
self._key = np.array(range(len(self._mapping))).astype('uint8') - 1
self._key = np.array(range(len(self._mapping))).astype('uint8')

self.train_detail = Detail(self.annFile, self._image_dir, 'train')
self.train_ids = self.train_detail.getImgs()
Expand Down

0 comments on commit 0e97663

Please sign in to comment.