Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#40 from heavengate/fix_tsm_hang
Browse files Browse the repository at this point in the history
fix tsm hang
  • Loading branch information
heavengate authored Apr 15, 2020
2 parents dc2a5e5 + 0872cfa commit 5ed8fa8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 20 deletions.
19 changes: 6 additions & 13 deletions examples/tsm/kinetics_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,12 @@ def __len__(self):
def __getitem__(self, idx):
pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx])

try:
if six.PY2:
data = pickle.load(open(pickle_path, 'rb'))
else:
data = pickle.load(open(pickle_path, 'rb'), encoding='bytes')

vid, label, frames = data
if len(frames) < 1:
logger.error("{} contains no frame".format(pickle_path))
sys.exit(-1)
except Exception as e:
logger.error("Load {} failed: {}".format(pickle_path, e))
sys.exit(-1)
if six.PY2:
data = pickle.load(open(pickle_path, 'rb'))
else:
data = pickle.load(open(pickle_path, 'rb'), encoding='bytes')

vid, label, frames = data

if self.label_list is not None:
label = self.label_list.index(label)
Expand Down
2 changes: 1 addition & 1 deletion hapi/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import os
import cv2
import numpy as np
from pycocotools.coco import COCO

from paddle.io import Dataset

Expand Down Expand Up @@ -91,6 +90,7 @@ def __init__(self,
self._load_roidb_and_cname2cid()

def _load_roidb_and_cname2cid(self):
from pycocotools.coco import COCO
assert self._anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path
coco = COCO(self._anno_path)
Expand Down
12 changes: 6 additions & 6 deletions hapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,12 +798,12 @@ def _check_match(key, param):
"{} receives a shape {}, but the expected shape is {}.".
format(key, list(state.shape), list(param.shape)))
return param, state

def _strip_postfix(path):
path, ext = os.path.splitext(path)
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
"Unknown postfix {} from weights".format(ext)
return path
def _strip_postfix(path):
path, ext = os.path.splitext(path)
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
"Unknown postfix {} from weights".format(ext)
return path

path = _strip_postfix(path)
param_state = _load_state_from_path(path + ".pdparams")
Expand Down

0 comments on commit 5ed8fa8

Please sign in to comment.