From 7a20d83f27bf04aca42f0028f8ad43c38da2d2da Mon Sep 17 00:00:00 2001 From: felixhjh <852142024@qq.com> Date: Sun, 24 Jul 2022 13:32:58 +0000 Subject: [PATCH 1/7] Detection evaluation function --- fastdeploy/vision/evaluation/__init__.py | 1 + fastdeploy/vision/evaluation/detection.py | 66 ++++++ .../vision/evaluation/utils/__init__.py | 8 + fastdeploy/vision/evaluation/utils/coco.py | 179 +++++++++++++++ .../vision/evaluation/utils/coco_utils.py | 217 ++++++++++++++++++ .../vision/evaluation/utils/fd_logging.py | 53 +++++ .../vision/evaluation/utils/json_results.py | 155 +++++++++++++ .../vision/evaluation/utils/map_utils.py | 27 +++ fastdeploy/vision/evaluation/utils/metrics.py | 75 ++++++ fastdeploy/vision/evaluation/utils/util.py | 20 ++ 10 files changed, 801 insertions(+) create mode 100644 fastdeploy/vision/evaluation/detection.py create mode 100644 fastdeploy/vision/evaluation/utils/__init__.py create mode 100644 fastdeploy/vision/evaluation/utils/coco.py create mode 100644 fastdeploy/vision/evaluation/utils/coco_utils.py create mode 100644 fastdeploy/vision/evaluation/utils/fd_logging.py create mode 100644 fastdeploy/vision/evaluation/utils/json_results.py create mode 100644 fastdeploy/vision/evaluation/utils/map_utils.py create mode 100644 fastdeploy/vision/evaluation/utils/metrics.py create mode 100644 fastdeploy/vision/evaluation/utils/util.py diff --git a/fastdeploy/vision/evaluation/__init__.py b/fastdeploy/vision/evaluation/__init__.py index 1158095ec5..d2c9a79116 100644 --- a/fastdeploy/vision/evaluation/__init__.py +++ b/fastdeploy/vision/evaluation/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from __future__ import absolute_import from .classify import eval_classify +from .detection import eval_detection diff --git a/fastdeploy/vision/evaluation/detection.py b/fastdeploy/vision/evaluation/detection.py new file mode 100644 index 0000000000..4aaaaaaa56 --- /dev/null +++ b/fastdeploy/vision/evaluation/detection.py @@ -0,0 +1,66 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tqdm import trange +import cv2 +import numpy as np +from .utils import CocoDetection +from .utils import COCOMetric +import copy +import collections + + +def eval_detection(model, + conf_threshold, + nms_iou_threshold, + data_dir, + ann_file, + plot=False): + assert isinstance(conf_threshold, ( + float, int + )), "The conf_threshold:{} need to be int or float".format(conf_threshold) + assert isinstance(nms_iou_threshold, ( + float, + int)), "The nms_iou_threshold:{} need to be int or float".format( + nms_iou_threshold) + eval_dataset = CocoDetection( + data_dir=data_dir, ann_file=ann_file, shuffle=False) + all_image_info = eval_dataset.file_list + image_num = eval_dataset.num_samples + eval_dataset.data_fields = { + 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class', 'is_crowd' + } + eval_metric = COCOMetric( + coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False) + scores = collections.OrderedDict() + for image_info, i in zip(all_image_info, + trange( + image_num, desc="Inference Progress")): + im = cv2.imread(image_info["image"]) + im_id = image_info["im_id"] + result = model.predict(im, conf_threshold, nms_iou_threshold) + pred = { + 'bbox': + [[c] + [s] + b + for b, s, c in zip(result.boxes, result.scores, result.label_ids) + ], + 'bbox_num': len(result.boxes), + 'im_id': im_id + } + eval_metric.update(im_id, pred) + eval_metric.accumulate() + eval_details = eval_metric.details + scores.update(eval_metric.get()) + eval_metric.reset() + return scores diff --git a/fastdeploy/vision/evaluation/utils/__init__.py b/fastdeploy/vision/evaluation/utils/__init__.py new file mode 100644 index 0000000000..4536b491d8 --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/__init__.py @@ -0,0 +1,8 @@ +from . import fd_logging +from .util import * +from .metrics import * +from .json_results import * +from .map_utils import * +from .coco_utils import * +from .coco import * +from .cityscapes import Cityscapes diff --git a/fastdeploy/vision/evaluation/utils/coco.py b/fastdeploy/vision/evaluation/utils/coco.py new file mode 100644 index 0000000000..079cde0c22 --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/coco.py @@ -0,0 +1,179 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import copy +import os.path as osp +import six +import sys +import numpy as np +from . import fd_logging as logging +from .util import is_pic, get_num_workers + + +class CocoDetection(object): + """读取MSCOCO格式的检测数据集,并对样本进行相应的处理,该格式的数据集同样可以应用到实例分割模型的训练中。 + + Args: + data_dir (str): 数据集所在的目录路径。 + ann_file (str): 数据集的标注文件,为一个独立的json格式文件。 + num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据 + 系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的一半。 + shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。 + allow_empty (bool): 是否加载负样本。默认为False。 + empty_ratio (float): 用于指定负样本占总样本数的比例。如果小于0或大于等于1,则保留全部的负样本。默认为1。 + """ + + def __init__(self, + data_dir, + ann_file, + num_workers='auto', + shuffle=False, + allow_empty=False, + empty_ratio=1.): + + from pycocotools.coco import COCO + self.data_dir = data_dir + self.data_fields = None + self.num_max_boxes = 1000 + self.num_workers = get_num_workers(num_workers) + self.shuffle = shuffle + self.allow_empty = allow_empty + self.empty_ratio = empty_ratio + self.file_list = list() + neg_file_list = list() + self.labels = list() + + coco = COCO(ann_file) + self.coco_gt = coco + img_ids = sorted(coco.getImgIds()) + cat_ids = coco.getCatIds() + catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)}) + cname2clsid = dict({ + coco.loadCats(catid)[0]['name']: clsid + for catid, clsid in catid2clsid.items() + }) + for label, cid in sorted(cname2clsid.items(), key=lambda d: d[1]): + self.labels.append(label) + logging.info("Starting to read file list from dataset...") + + ct = 0 + for img_id in img_ids: + is_empty = False + img_anno = coco.loadImgs(img_id)[0] + im_fname = osp.join(data_dir, img_anno['file_name']) + if not is_pic(im_fname): + continue + im_w = float(img_anno['width']) + im_h = float(img_anno['height']) + ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False) + instances = coco.loadAnns(ins_anno_ids) + + bboxes = [] + for inst in instances: + x, y, box_w, box_h = inst['bbox'] + x1 = max(0, x) + y1 = max(0, y) + x2 = min(im_w - 1, x1 + max(0, box_w)) + y2 = min(im_h - 1, y1 + max(0, box_h)) + if inst['area'] > 0 and x2 >= x1 and y2 >= y1: + inst['clean_bbox'] = [x1, y1, x2, y2] + bboxes.append(inst) + else: + logging.warning( + "Found an invalid bbox in annotations: " + "im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}." + .format(img_id, float(inst['area']), x1, y1, x2, y2)) + num_bbox = len(bboxes) + if num_bbox == 0 and not self.allow_empty: + continue + elif num_bbox == 0: + is_empty = True + + gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) + gt_class = np.zeros((num_bbox, 1), dtype=np.int32) + gt_score = np.ones((num_bbox, 1), dtype=np.float32) + is_crowd = np.zeros((num_bbox, 1), dtype=np.int32) + difficult = np.zeros((num_bbox, 1), dtype=np.int32) + gt_poly = [None] * num_bbox + + has_segmentation = False + for i, box in reversed(list(enumerate(bboxes))): + catid = box['category_id'] + gt_class[i][0] = catid2clsid[catid] + gt_bbox[i, :] = box['clean_bbox'] + is_crowd[i][0] = box['iscrowd'] + if 'segmentation' in box and box['iscrowd'] == 1: + gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]] + elif 'segmentation' in box and box['segmentation']: + if not np.array( + box['segmentation'], + dtype=object).size > 0 and not self.allow_empty: + gt_poly.pop(i) + is_crowd = np.delete(is_crowd, i) + gt_class = np.delete(gt_class, i) + gt_bbox = np.delete(gt_bbox, i) + else: + gt_poly[i] = box['segmentation'] + has_segmentation = True + if has_segmentation and not any(gt_poly) and not self.allow_empty: + continue + + im_info = { + 'im_id': np.array([img_id]).astype('int32'), + 'image_shape': np.array([im_h, im_w]).astype('int32'), + } + label_info = { + 'is_crowd': is_crowd, + 'gt_class': gt_class, + 'gt_bbox': gt_bbox, + 'gt_score': gt_score, + 'gt_poly': gt_poly, + 'difficult': difficult + } + + if is_empty: + neg_file_list.append({ + 'image': im_fname, + ** + im_info, + ** + label_info + }) + else: + self.file_list.append({ + 'image': im_fname, + ** + im_info, + ** + label_info + }) + ct += 1 + + self.num_max_boxes = max(self.num_max_boxes, len(instances)) + + if not ct: + logging.error( + "No coco record found in %s' % (ann_file)", exit=True) + self.pos_num = len(self.file_list) + if self.allow_empty and neg_file_list: + self.file_list += self._sample_empty(neg_file_list) + logging.info( + "{} samples in file {}, including {} positive samples and {} negative samples.". + format( + len(self.file_list), ann_file, self.pos_num, + len(self.file_list) - self.pos_num)) + self.num_samples = len(self.file_list) + + self._epoch = 0 diff --git a/fastdeploy/vision/evaluation/utils/coco_utils.py b/fastdeploy/vision/evaluation/utils/coco_utils.py new file mode 100644 index 0000000000..66282148b1 --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/coco_utils.py @@ -0,0 +1,217 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import numpy as np +import itertools +from .map_utils import draw_pr_curve +from .json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res +import logging as logging +import copy + + +def loadRes(coco_obj, anns): + """ + Load result file and return a result api object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + + # This function has the same functionality as pycocotools.COCO.loadRes, + # except that the input anns is list of results rather than a json file. + # Refer to + # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305, + + # matplotlib.use() must be called *before* pylab, matplotlib.pyplot, + # or matplotlib.backends is imported for the first time + # pycocotools import matplotlib + import matplotlib + matplotlib.use('Agg') + from pycocotools.coco import COCO + import pycocotools.mask as maskUtils + import time + res = COCO() + res.dataset['images'] = [img for img in coco_obj.dataset['images']] + + tic = time.time() + assert type(anns) == list, 'results in not an array of objects' + annsImgIds = [ann['image_id'] for ann in anns] + assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \ + 'Results do not correspond to current coco set' + if 'caption' in anns[0]: + imgIds = set([img['id'] for img in res.dataset['images']]) & set( + [ann['image_id'] for ann in anns]) + res.dataset['images'] = [ + img for img in res.dataset['images'] if img['id'] in imgIds + ] + for id, ann in enumerate(anns): + ann['id'] = id + 1 + elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: + res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[ + 'categories']) + for id, ann in enumerate(anns): + bb = ann['bbox'] + x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] + if not 'segmentation' in ann: + ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + ann['area'] = bb[2] * bb[3] + ann['id'] = id + 1 + ann['iscrowd'] = 0 + elif 'segmentation' in anns[0]: + res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[ + 'categories']) + for id, ann in enumerate(anns): + # now only support compressed RLE format as segmentation results + ann['area'] = maskUtils.area(ann['segmentation']) + if not 'bbox' in ann: + ann['bbox'] = maskUtils.toBbox(ann['segmentation']) + ann['id'] = id + 1 + ann['iscrowd'] = 0 + elif 'keypoints' in anns[0]: + res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[ + 'categories']) + for id, ann in enumerate(anns): + s = ann['keypoints'] + x = s[0::3] + y = s[1::3] + x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y) + ann['area'] = (x1 - x0) * (y1 - y0) + ann['id'] = id + 1 + ann['bbox'] = [x0, y0, x1 - x0, y1 - y0] + + res.dataset['annotations'] = anns + res.createIndex() + return res + + +def get_infer_results(outs, catid, bias=0): + """ + Get result at the stage of inference. + The output format is dictionary containing bbox or mask result. + + For example, bbox result is a list and each element contains + image_id, category_id, bbox and score. + """ + if outs is None or len(outs) == 0: + raise ValueError( + 'The number of valid detection result if zero. Please use reasonable model and check input data.' + ) + + im_id = outs['im_id'] + + infer_res = {} + if 'bbox' in outs: + if len(outs['bbox']) > 0 and len(outs['bbox'][0]) > 6: + infer_res['bbox'] = get_det_poly_res( + outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias) + else: + infer_res['bbox'] = get_det_res( + outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias) + + if 'mask' in outs: + # mask post process + infer_res['mask'] = get_seg_res(outs['mask'], outs['bbox'], + outs['bbox_num'], im_id, catid) + + if 'segm' in outs: + infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid) + + return infer_res + + +def cocoapi_eval(anns, + style, + coco_gt=None, + anno_file=None, + max_dets=(100, 300, 1000), + classwise=False): + """ + Args: + anns: Evaluation result. + style (str): COCOeval style, can be `bbox` , `segm` and `proposal`. + coco_gt (str): Whether to load COCOAPI through anno_file, + eg: coco_gt = COCO(anno_file) + anno_file (str): COCO annotations file. + max_dets (tuple): COCO evaluation maxDets. + classwise (bool): Whether per-category AP and draw P-R Curve or not. + """ + assert coco_gt is not None or anno_file is not None + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + if coco_gt is None: + coco_gt = COCO(anno_file) + logging.info("Start evaluate...") + coco_dt = loadRes(coco_gt, anns) + if style == 'proposal': + coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') + coco_eval.params.useCats = 0 + coco_eval.params.maxDets = list(max_dets) + else: + coco_eval = COCOeval(coco_gt, coco_dt, style) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if classwise: + # Compute per-category AP and PR curve + try: + from terminaltables import AsciiTable + except Exception as e: + logging.error( + 'terminaltables not found, plaese install terminaltables. ' + 'for example: `pip install terminaltables`.') + raise e + precisions = coco_eval.eval['precision'] + cat_ids = coco_gt.getCatIds() + # precision: (iou, recall, cls, area range, max dets) + assert len(cat_ids) == precisions.shape[2] + results_per_category = [] + for idx, catId in enumerate(cat_ids): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = coco_gt.loadCats(catId)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + results_per_category.append( + (str(nm["name"]), '{:0.3f}'.format(float(ap)))) + pr_array = precisions[0, :, idx, 0, 2] + recall_array = np.arange(0.0, 1.01, 0.01) + draw_pr_curve( + pr_array, + recall_array, + out_dir=style + '_pr_curve', + file_name='{}_precision_recall_curve.jpg'.format(nm["name"])) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + headers = ['category', 'AP'] * (num_columns // 2) + results_2d = itertools.zip_longest( + * [results_flatten[i::num_columns] for i in range(num_columns)]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + logging.info('Per-category of {} AP: \n{}'.format(style, table.table)) + logging.info("per-category PR curve has output to {} folder.".format( + style + '_pr_curve')) + # flush coco evaluation result + sys.stdout.flush() + return coco_eval.stats diff --git a/fastdeploy/vision/evaluation/utils/fd_logging.py b/fastdeploy/vision/evaluation/utils/fd_logging.py new file mode 100644 index 0000000000..02c0b5d024 --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/fd_logging.py @@ -0,0 +1,53 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import os +import sys +import colorama +from colorama import init + +init(autoreset=True) +levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} + + +def log(level=2, message="", use_color=False): + current_time = time.time() + time_array = time.localtime(current_time) + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) + if use_color: + print("\033[1;31;40m{} [{}]\t{}\033[0m".format(current_time, levels[ + level], message).encode("utf-8").decode("latin1")) + else: + print("{} [{}]\t{}".format(current_time, levels[level], message) + .encode("utf-8").decode("latin1")) + sys.stdout.flush() + + +def debug(message="", use_color=False): + log(level=3, message=message, use_color=use_color) + + +def info(message="", use_color=False): + log(level=2, message=message, use_color=use_color) + + +def warning(message="", use_color=True): + log(level=1, message=message, use_color=use_color) + + +def error(message="", use_color=True, exit=True): + log(level=0, message=message, use_color=use_color) + if exit: + sys.exit(-1) diff --git a/fastdeploy/vision/evaluation/utils/json_results.py b/fastdeploy/vision/evaluation/utils/json_results.py new file mode 100644 index 0000000000..c144111b5b --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/json_results.py @@ -0,0 +1,155 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import six +import numpy as np + + +def get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0): + det_res = [] + for i in range(bbox_nums): + cur_image_id = int(image_id) + dt = bboxes[i] + num_id, score, xmin, ymin, xmax, ymax = dt + if int(num_id) < 0: + continue + category_id = label_to_cat_id_map[int(num_id)] + w = xmax - xmin + bias + h = ymax - ymin + bias + bbox = [xmin, ymin, w, h] + dt_res = { + 'image_id': cur_image_id, + 'category_id': category_id, + 'bbox': bbox, + 'score': score + } + det_res.append(dt_res) + return det_res + + +def get_det_poly_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0): + det_res = [] + k = 0 + for i in range(len(bbox_nums)): + cur_image_id = int(image_id[i][0]) + det_nums = bbox_nums[i] + for j in range(det_nums): + dt = bboxes[k] + k = k + 1 + num_id, score, x1, y1, x2, y2, x3, y3, x4, y4 = dt.tolist() + if int(num_id) < 0: + continue + category_id = label_to_cat_id_map[int(num_id)] + rbox = [x1, y1, x2, y2, x3, y3, x4, y4] + dt_res = { + 'image_id': cur_image_id, + 'category_id': category_id, + 'bbox': rbox, + 'score': score + } + det_res.append(dt_res) + return det_res + + +def strip_mask(mask): + row = mask[0, 0, :] + col = mask[0, :, 0] + im_h = len(col) - np.count_nonzero(col == -1) + im_w = len(row) - np.count_nonzero(row == -1) + return mask[:, :im_h, :im_w] + + +def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map): + import pycocotools.mask as mask_util + seg_res = [] + k = 0 + for i in range(len(mask_nums)): + cur_image_id = int(image_id[i][0]) + det_nums = mask_nums[i] + mask_i = masks[k:k + det_nums] + mask_i = strip_mask(mask_i) + for j in range(det_nums): + mask = mask_i[j].astype(np.uint8) + score = float(bboxes[k][1]) + label = int(bboxes[k][0]) + k = k + 1 + if label == -1: + continue + cat_id = label_to_cat_id_map[label] + rle = mask_util.encode( + np.array( + mask[:, :, None], order="F", dtype="uint8"))[0] + if six.PY3: + if 'counts' in rle: + rle['counts'] = rle['counts'].decode("utf8") + sg_res = { + 'image_id': cur_image_id, + 'category_id': cat_id, + 'segmentation': rle, + 'score': score + } + seg_res.append(sg_res) + return seg_res + + +def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map): + import pycocotools.mask as mask_util + segm_res = [] + # for each batch + segms = results['segm'].astype(np.uint8) + clsid_labels = results['cate_label'] + clsid_scores = results['cate_score'] + lengths = segms.shape[0] + im_id = int(image_id[0][0]) + if lengths == 0 or segms is None: + return None + # for each sample + for i in range(lengths - 1): + clsid = int(clsid_labels[i]) + catid = num_id_to_cat_id_map[clsid] + score = float(clsid_scores[i]) + mask = segms[i] + segm = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'image_id': im_id, + 'category_id': catid, + 'segmentation': segm, + 'score': score + } + segm_res.append(coco_res) + return segm_res + + +def get_keypoint_res(results, im_id): + anns = [] + preds = results['keypoint'] + for idx in range(im_id.shape[0]): + image_id = im_id[idx].item() + kpts, scores = preds[idx] + for kpt, score in zip(kpts, scores): + kpt = kpt.flatten() + ann = { + 'image_id': image_id, + 'category_id': 1, # XXX hard code + 'keypoints': kpt.tolist(), + 'score': float(score) + } + x = kpt[0::3] + y = kpt[1::3] + x0, x1, y0, y1 = np.min(x).item(), np.max(x).item(), np.min( + y).item(), np.max(y).item() + ann['area'] = (x1 - x0) * (y1 - y0) + ann['bbox'] = [x0, y0, x1 - x0, y1 - y0] + anns.append(ann) + return anns diff --git a/fastdeploy/vision/evaluation/utils/map_utils.py b/fastdeploy/vision/evaluation/utils/map_utils.py new file mode 100644 index 0000000000..2ccd691913 --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/map_utils.py @@ -0,0 +1,27 @@ +from __future__ import absolute_import + +import os + + +def draw_pr_curve(precision, + recall, + iou=0.5, + out_dir='pr_curve', + file_name='precision_recall_curve.jpg'): + if not os.path.exists(out_dir): + os.makedirs(out_dir) + output_path = os.path.join(out_dir, file_name) + try: + import matplotlib.pyplot as plt + except Exception as e: + # logger.error('Matplotlib not found, plaese install matplotlib.' + # 'for example: `pip install matplotlib`.') + raise e + plt.cla() + plt.figure('P-R Curve') + plt.title('Precision/Recall Curve(IoU={})'.format(iou)) + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.grid(True) + plt.plot(recall, precision) + plt.savefig(output_path) diff --git a/fastdeploy/vision/evaluation/utils/metrics.py b/fastdeploy/vision/evaluation/utils/metrics.py new file mode 100644 index 0000000000..37d3ea1b98 --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/metrics.py @@ -0,0 +1,75 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import sys +from collections import OrderedDict +from .coco_utils import get_infer_results, cocoapi_eval + + +class COCOMetric(object): + def __init__(self, coco_gt, **kwargs): + self.clsid2catid = { + i: cat['id'] + for i, cat in enumerate(coco_gt.loadCats(coco_gt.getCatIds())) + } + self.coco_gt = coco_gt + self.classwise = kwargs.get('classwise', False) + self.bias = 0 + self.reset() + + def reset(self): + # only bbox and mask evaluation support currently + self.details = { + 'gt': copy.deepcopy(self.coco_gt.dataset), + 'bbox': [], + 'mask': [] + } + self.eval_stats = {} + + def update(self, im_id, outputs): + outs = {} + # outputs Tensor -> numpy.ndarray + for k, v in outputs.items(): + outs[k] = v + + outs['im_id'] = im_id + infer_results = get_infer_results( + outs, self.clsid2catid, bias=self.bias) + self.details['bbox'] += infer_results[ + 'bbox'] if 'bbox' in infer_results else [] + self.details['mask'] += infer_results[ + 'mask'] if 'mask' in infer_results else [] + + def accumulate(self): + if len(self.details['bbox']) > 0: + bbox_stats = cocoapi_eval( + copy.deepcopy(self.details['bbox']), + 'bbox', + coco_gt=self.coco_gt, + classwise=self.classwise) + self.eval_stats['bbox'] = bbox_stats + sys.stdout.flush() + + if len(self.details['mask']) > 0: + seg_stats = cocoapi_eval( + copy.deepcopy(self.details['mask']), + 'segm', + coco_gt=self.coco_gt, + classwise=self.classwise) + self.eval_stats['mask'] = seg_stats + sys.stdout.flush() + + def log(self): + pass + + def get(self): + if 'bbox' not in self.eval_stats: + return {'bbox_mmap': 0.} + if 'mask' in self.eval_stats: + return OrderedDict( + zip(['bbox_mmap', 'segm_mmap'], + [self.eval_stats['bbox'][0], self.eval_stats['mask'][0]])) + else: + return {'bbox_mmap': self.eval_stats['bbox'][0]} diff --git a/fastdeploy/vision/evaluation/utils/util.py b/fastdeploy/vision/evaluation/utils/util.py new file mode 100644 index 0000000000..7cae5e5aea --- /dev/null +++ b/fastdeploy/vision/evaluation/utils/util.py @@ -0,0 +1,20 @@ +import platform +import multiprocessing as mp + + +def is_pic(img_name): + valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'] + suffix = img_name.split('.')[-1] + if suffix not in valid_suffix: + return False + return True + + +def get_num_workers(num_workers): + if not platform.system() == 'Linux': + # Dataloader with multi-process model is not supported + # on MacOS and Windows currently. + return 0 + if num_workers == 'auto': + num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 2 else 2 + return num_workers From 1b4f7090677861e0f750d866982512faec2d378b Mon Sep 17 00:00:00 2001 From: felixhjh <852142024@qq.com> Date: Tue, 26 Jul 2022 07:43:11 +0000 Subject: [PATCH 2/7] Add license --- fastdeploy/vision/evaluation/utils/__init__.py | 14 ++++++++++++++ fastdeploy/vision/evaluation/utils/coco.py | 4 ++-- fastdeploy/vision/evaluation/utils/coco_utils.py | 2 +- fastdeploy/vision/evaluation/utils/fd_logging.py | 4 ++-- .../vision/evaluation/utils/json_results.py | 3 ++- fastdeploy/vision/evaluation/utils/map_utils.py | 15 ++++++++++++++- fastdeploy/vision/evaluation/utils/metrics.py | 14 ++++++++++++++ fastdeploy/vision/evaluation/utils/util.py | 14 ++++++++++++++ 8 files changed, 63 insertions(+), 7 deletions(-) diff --git a/fastdeploy/vision/evaluation/utils/__init__.py b/fastdeploy/vision/evaluation/utils/__init__.py index 4536b491d8..dfcb419bad 100644 --- a/fastdeploy/vision/evaluation/utils/__init__.py +++ b/fastdeploy/vision/evaluation/utils/__init__.py @@ -1,3 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from . import fd_logging from .util import * from .metrics import * diff --git a/fastdeploy/vision/evaluation/utils/coco.py b/fastdeploy/vision/evaluation/utils/coco.py index 079cde0c22..c675790557 100644 --- a/fastdeploy/vision/evaluation/utils/coco.py +++ b/fastdeploy/vision/evaluation/utils/coco.py @@ -1,10 +1,10 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/fastdeploy/vision/evaluation/utils/coco_utils.py b/fastdeploy/vision/evaluation/utils/coco_utils.py index 66282148b1..9d551f253f 100644 --- a/fastdeploy/vision/evaluation/utils/coco_utils.py +++ b/fastdeploy/vision/evaluation/utils/coco_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/fastdeploy/vision/evaluation/utils/fd_logging.py b/fastdeploy/vision/evaluation/utils/fd_logging.py index 02c0b5d024..12091a4f75 100644 --- a/fastdeploy/vision/evaluation/utils/fd_logging.py +++ b/fastdeploy/vision/evaluation/utils/fd_logging.py @@ -1,10 +1,10 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/fastdeploy/vision/evaluation/utils/json_results.py b/fastdeploy/vision/evaluation/utils/json_results.py index c144111b5b..b2e816025b 100644 --- a/fastdeploy/vision/evaluation/utils/json_results.py +++ b/fastdeploy/vision/evaluation/utils/json_results.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import six import numpy as np diff --git a/fastdeploy/vision/evaluation/utils/map_utils.py b/fastdeploy/vision/evaluation/utils/map_utils.py index 2ccd691913..12ea43d3c6 100644 --- a/fastdeploy/vision/evaluation/utils/map_utils.py +++ b/fastdeploy/vision/evaluation/utils/map_utils.py @@ -1,5 +1,18 @@ -from __future__ import absolute_import +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import import os diff --git a/fastdeploy/vision/evaluation/utils/metrics.py b/fastdeploy/vision/evaluation/utils/metrics.py index 37d3ea1b98..ece5036937 100644 --- a/fastdeploy/vision/evaluation/utils/metrics.py +++ b/fastdeploy/vision/evaluation/utils/metrics.py @@ -1,3 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/fastdeploy/vision/evaluation/utils/util.py b/fastdeploy/vision/evaluation/utils/util.py index 7cae5e5aea..700ac2cbed 100644 --- a/fastdeploy/vision/evaluation/utils/util.py +++ b/fastdeploy/vision/evaluation/utils/util.py @@ -1,3 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import platform import multiprocessing as mp From 9d8a94b2a727cab7ad948943252325afb7914d88 Mon Sep 17 00:00:00 2001 From: felixhjh <852142024@qq.com> Date: Wed, 27 Jul 2022 07:55:38 +0000 Subject: [PATCH 3/7] Fix python import problem --- fastdeploy/vision/evaluation/utils/__init__.py | 1 - fastdeploy/vision/evaluation/utils/coco.py | 1 - fastdeploy/vision/evaluation/utils/coco_utils.py | 5 +++-- requirements.txt | 1 + 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fastdeploy/vision/evaluation/utils/__init__.py b/fastdeploy/vision/evaluation/utils/__init__.py index dfcb419bad..afa10c0e85 100644 --- a/fastdeploy/vision/evaluation/utils/__init__.py +++ b/fastdeploy/vision/evaluation/utils/__init__.py @@ -19,4 +19,3 @@ from .map_utils import * from .coco_utils import * from .coco import * -from .cityscapes import Cityscapes diff --git a/fastdeploy/vision/evaluation/utils/coco.py b/fastdeploy/vision/evaluation/utils/coco.py index c675790557..70a9714c28 100644 --- a/fastdeploy/vision/evaluation/utils/coco.py +++ b/fastdeploy/vision/evaluation/utils/coco.py @@ -15,7 +15,6 @@ from __future__ import absolute_import import copy import os.path as osp -import six import sys import numpy as np from . import fd_logging as logging diff --git a/fastdeploy/vision/evaluation/utils/coco_utils.py b/fastdeploy/vision/evaluation/utils/coco_utils.py index 9d551f253f..18a25aa0e8 100644 --- a/fastdeploy/vision/evaluation/utils/coco_utils.py +++ b/fastdeploy/vision/evaluation/utils/coco_utils.py @@ -18,10 +18,9 @@ import sys import numpy as np -import itertools from .map_utils import draw_pr_curve from .json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res -import logging as logging +import .fd_logging as logging import copy @@ -202,6 +201,8 @@ def cocoapi_eval(anns, file_name='{}_precision_recall_curve.jpg'.format(nm["name"])) num_columns = min(6, len(results_per_category) * 2) + + import itertools results_flatten = list(itertools.chain(*results_per_category)) headers = ['category', 'AP'] * (num_columns // 2) results_2d = itertools.zip_longest( diff --git a/requirements.txt b/requirements.txt index 276109fc15..8440e24f85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ opencv-python tqdm +numpy From 6b90b67a8293305af63f304db78060a6022cb1d8 Mon Sep 17 00:00:00 2001 From: felixhjh <852142024@qq.com> Date: Wed, 27 Jul 2022 08:27:24 +0000 Subject: [PATCH 4/7] Modify requirement.txt --- fastdeploy/vision/evaluation/utils/coco_utils.py | 2 +- requirements.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/vision/evaluation/utils/coco_utils.py b/fastdeploy/vision/evaluation/utils/coco_utils.py index 18a25aa0e8..ac7ba3333d 100644 --- a/fastdeploy/vision/evaluation/utils/coco_utils.py +++ b/fastdeploy/vision/evaluation/utils/coco_utils.py @@ -20,7 +20,7 @@ import numpy as np from .map_utils import draw_pr_curve from .json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res -import .fd_logging as logging +from . import fd_logging as logging import copy diff --git a/requirements.txt b/requirements.txt index 8440e24f85..05a82366e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ opencv-python tqdm numpy +pycocotools From 19f23e6548f5d2d96d7edf2fd463ba532799047a Mon Sep 17 00:00:00 2001 From: felixhjh <852142024@qq.com> Date: Wed, 27 Jul 2022 08:29:34 +0000 Subject: [PATCH 5/7] Add requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 05a82366e0..7e18ca0346 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ opencv-python tqdm numpy pycocotools +colorama From d101ebe80f57f105ce02e54a19d20be37b078f87 Mon Sep 17 00:00:00 2001 From: felixhjh <852142024@qq.com> Date: Thu, 28 Jul 2022 07:23:24 +0000 Subject: [PATCH 6/7] Evaluation support model containing nms --- fastdeploy/vision/evaluation/detection.py | 28 +++++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/fastdeploy/vision/evaluation/detection.py b/fastdeploy/vision/evaluation/detection.py index 4aaaaaaa56..43bd7624e7 100644 --- a/fastdeploy/vision/evaluation/detection.py +++ b/fastdeploy/vision/evaluation/detection.py @@ -20,20 +20,25 @@ import copy import collections +nms_include = ['PaddleDetection/PPYOLOE'] + def eval_detection(model, - conf_threshold, - nms_iou_threshold, data_dir, ann_file, + conf_threshold=None, + nms_iou_threshold=None, plot=False): - assert isinstance(conf_threshold, ( - float, int - )), "The conf_threshold:{} need to be int or float".format(conf_threshold) - assert isinstance(nms_iou_threshold, ( - float, - int)), "The nms_iou_threshold:{} need to be int or float".format( - nms_iou_threshold) + if conf_threshold is not None or nms_iou_threshold is not None: + assert conf_threshold is not None and nms_iou_threshold is not None, "The conf_threshold and nms_iou_threshold should be setted at the same time" + assert isinstance(conf_threshold, ( + float, + int)), "The conf_threshold:{} need to be int or float".format( + conf_threshold) + assert isinstance(nms_iou_threshold, ( + float, + int)), "The nms_iou_threshold:{} need to be int or float".format( + nms_iou_threshold) eval_dataset = CocoDetection( data_dir=data_dir, ann_file=ann_file, shuffle=False) all_image_info = eval_dataset.file_list @@ -49,7 +54,10 @@ def eval_detection(model, image_num, desc="Inference Progress")): im = cv2.imread(image_info["image"]) im_id = image_info["im_id"] - result = model.predict(im, conf_threshold, nms_iou_threshold) + if conf_threshold is None and nms_iou_threshold is None: + result = model.predict(im) + else: + result = model.predict(im, conf_threshold, nms_iou_threshold) pred = { 'bbox': [[c] + [s] + b From abfc710de55e00e5a64fc80f256436e70573df05 Mon Sep 17 00:00:00 2001 From: felixhjh <852142024@qq.com> Date: Thu, 28 Jul 2022 08:02:38 +0000 Subject: [PATCH 7/7] Delete useless code --- fastdeploy/vision/evaluation/detection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastdeploy/vision/evaluation/detection.py b/fastdeploy/vision/evaluation/detection.py index 43bd7624e7..cd09046f7e 100644 --- a/fastdeploy/vision/evaluation/detection.py +++ b/fastdeploy/vision/evaluation/detection.py @@ -20,8 +20,6 @@ import copy import collections -nms_include = ['PaddleDetection/PPYOLOE'] - def eval_detection(model, data_dir,