diff --git a/examples/tsm/kinetics_dataset.py b/examples/tsm/kinetics_dataset.py index c8570018cfbcf..123d89814a8c6 100644 --- a/examples/tsm/kinetics_dataset.py +++ b/examples/tsm/kinetics_dataset.py @@ -100,19 +100,12 @@ def __len__(self): def __getitem__(self, idx): pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx]) - try: - if six.PY2: - data = pickle.load(open(pickle_path, 'rb')) - else: - data = pickle.load(open(pickle_path, 'rb'), encoding='bytes') - - vid, label, frames = data - if len(frames) < 1: - logger.error("{} contains no frame".format(pickle_path)) - sys.exit(-1) - except Exception as e: - logger.error("Load {} failed: {}".format(pickle_path, e)) - sys.exit(-1) + if six.PY2: + data = pickle.load(open(pickle_path, 'rb')) + else: + data = pickle.load(open(pickle_path, 'rb'), encoding='bytes') + + vid, label, frames = data if self.label_list is not None: label = self.label_list.index(label) diff --git a/hapi/datasets/coco.py b/hapi/datasets/coco.py index f1ab97281a6e0..50d31cff06692 100644 --- a/hapi/datasets/coco.py +++ b/hapi/datasets/coco.py @@ -18,7 +18,6 @@ import os import cv2 import numpy as np -from pycocotools.coco import COCO from paddle.io import Dataset @@ -91,6 +90,7 @@ def __init__(self, self._load_roidb_and_cname2cid() def _load_roidb_and_cname2cid(self): + from pycocotools.coco import COCO assert self._anno_path.endswith('.json'), \ 'invalid coco annotation file: ' + anno_path coco = COCO(self._anno_path) diff --git a/hapi/model.py b/hapi/model.py index f4e6744df5107..ed891f58a95f3 100644 --- a/hapi/model.py +++ b/hapi/model.py @@ -798,12 +798,12 @@ def _check_match(key, param): "{} receives a shape {}, but the expected shape is {}.". format(key, list(state.shape), list(param.shape))) return param, state - - def _strip_postfix(path): - path, ext = os.path.splitext(path) - assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ - "Unknown postfix {} from weights".format(ext) - return path + + def _strip_postfix(path): + path, ext = os.path.splitext(path) + assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ + "Unknown postfix {} from weights".format(ext) + return path path = _strip_postfix(path) param_state = _load_state_from_path(path + ".pdparams")