Revert "merge dataset test (#46)"
This reverts commit ac197fe.
rogerwwww authored Dec 18, 2022
1 parent ac197fe commit ec8fe51
Showing 5 changed files with 75 additions and 133 deletions.
67 changes: 1 addition & 66 deletions .github/workflows/python-package.yml
@@ -10,7 +10,7 @@ on:
branches: [ main ]

jobs:
linux:
build:

runs-on: ubuntu-latest
strategy:
@@ -40,68 +40,3 @@ jobs:
pytest --cov=pygmtools --cov-report=xml
- name: Upload to codecov
uses: codecov/codecov-action@v3

# macos:
#
# runs-on: macos-latest
# strategy:
# fail-fast: false
# matrix:
# python-version: [ "3.7", "3.8", "3.9" ]
#
# steps:
# - uses: actions/checkout@v2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v2
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# brew reinstall libomp
# brew --prefix libomp
# export LIBRARY_PATH=/usr/local/opt
# python -m pip install --upgrade pip
# python -m pip install flake8 pytest-cov
# pip install -r tests/requirements.txt
# - name: Lint with flake8
# run: |
# # stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# - name: Test with pytest
# run: |
# pytest --cov=pygmtools --cov-report=xml
# - name: Upload to codecov
# uses: codecov/codecov-action@v3
#
# windows:
#
# runs-on: windows-latest
# strategy:
# fail-fast: false
# matrix:
# python-version: [ "3.8", "3.9" ]
#
# steps:
# - uses: actions/checkout@v2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v2
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# python -m pip install flake8 pytest-cov
# python -m pip install -r tests\requirements.txt
# - name: Lint with flake8
# run: |
# # stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# - name: Test with pytest
# run: |
# pytest --cov=pygmtools --cov-report=xml
# - name: Upload to codecov
# uses: codecov/codecov-action@v3
33 changes: 6 additions & 27 deletions pygmtools/benchmark.py
@@ -76,7 +76,7 @@ def get_data(self, ids, test=False, shuffle=True):
:param ids: list of image ID, usually in ``train.json`` or ``test.json``
:param test: bool, whether the fetched data is used for test; if true, this function will not return ground truth
:param shuffle: bool, whether to shuffle the order of keypoints
:param shuffle: bool, whether to shuffle the order of keypoints; valid only when the class param ``sets`` is ``'train'``
:return:
**data_list**: list of data, like ``[{'img': np.array, 'kpts': coordinates of kpts}, ...]``
@@ -103,7 +103,7 @@ def get_data(self, ids, test=False, shuffle=True):
obj_dict['kpts'] = self.data_dict[keys]['kpts']
obj_dict['cls'] = self.data_dict[keys]['cls']
obj_dict['univ_size'] = self.data_dict[keys]['univ_size']
if shuffle:
if shuffle and self.sets != 'test':
random.shuffle(obj_dict['kpts'])
data_list.append(obj_dict)

@@ -197,7 +197,7 @@ def rand_get_data(self, cls=None, num=2, test=False, shuffle=True):
:param cls: int or str, class of expected data. None for random class
:param num: int, number of images; for example, 2 for 2GM
:param test: bool, whether the fetched data is used for test; if true, this function will not return ground truth
:param shuffle: bool, whether to shuffle the order of keypoints
:param shuffle: bool, whether to shuffle the order of keypoints; valid only when the class param ``sets`` is ``'train'``
:return:
**data_list**: list of data, like ``[{'img': np.array, 'kpts': coordinates of kpts}, ...]``
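For context, a minimal usage sketch of the behaviour the hunks above restore: keypoint shuffling is applied only when the benchmark was built with ``sets='train'`` and is skipped for the test set, so that predictions stay aligned with the cached ground truth. The ``Benchmark`` constructor arguments and the 3-tuple return value are assumptions based on the pygmtools documentation, not part of this commit.

```python
# Sketch only: constructor arguments and return unpacking are assumed, not taken
# from this diff; rand_get_data's signature matches the hunk above.
from pygmtools.benchmark import Benchmark

bm_train = Benchmark(name='WillowObject', sets='train', obj_resize=(256, 256))
bm_test = Benchmark(name='WillowObject', sets='test', obj_resize=(256, 256))

# With sets='train', shuffle=True randomises the keypoint order of each image.
data_list, perm_mat_dict, ids = bm_train.rand_get_data(cls='Car', num=2, shuffle=True)

# With sets='test', the shuffle flag is ignored (see `if shuffle and self.sets != 'test'`).
data_list, perm_mat_dict, ids = bm_test.rand_get_data(cls='Car', num=2, shuffle=True)
```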
@@ -365,24 +365,14 @@ def compute_img_num(self, classes):

return num_list

def eval(self, prediction, classes, verbose=False, rm_gt_cache=True):
def eval(self, prediction, classes, verbose=False):
r"""
Evaluate test results and compute matching accuracy and coverage.
:param prediction: list, prediction result, like ``[{'ids': (id1, id2), 'cls': cls, 'permmat': np.array or scipy.sparse}, ...]``
:param classes: list of evaluated classes
:param verbose: bool, whether to print the result
:param rm_gt_cache: bool, whether to remove ground truth cache
:return: evaluation result in each class and their averages, including p, r, f1 and their standard deviation and coverage
.. note::
If there are duplicate data pairs in ``prediction``, this function will only evaluate the first pair and
expect that this pair is also the first fetched pair. Therefore, it is recommended that ``prediction`` is
built in an ordered manner, and not shuffled.
.. note::
Ground truth cache is saved when data pairs are fetched, and should be removed after evaluation. Make sure
all data pairs are evaluated at once, i.e., ``prediction`` should contain all fetched data pairs.
"""

with open(self.data_list_path) as f1:
@@ -482,8 +472,6 @@ def eval(self, prediction, classes, verbose=False, rm_gt_cache=True):
result['mean']['recall'], result['mean']['recall_std'],
result['mean']['f1'], result['mean']['f1_std']
)))
if rm_gt_cache:
self.rm_gt_cache(last_epoch=False)
return result

def eval_cls(self, prediction, cls, verbose=False):
@@ -494,15 +482,6 @@ def eval_cls(self, prediction, cls, verbose=False):
:param cls: str, evaluated class
:param verbose: bool, whether to print the result
:return: evaluation result on the specified class, including p, r, f1 and their standard deviation and coverage
.. note::
If there are duplicate data pairs in ``prediction``, this function will only evaluate the first pair and
expect that this pair is also the first fetched pair. Therefore, it is recommended that ``prediction`` is
built in an ordered manner, and not shuffled. Same as the function ``eval``.
.. note::
This function will not automatically remove ground truth cache. However, you can still manually call the
class function ``rm_gt_cache`` to remove ground truth cache after evaluation.
"""

with open(self.data_list_path) as f1:
@@ -568,9 +547,9 @@ class function ``rm_gt_cache`` to remove ground truth cache after evaluation.

def rm_gt_cache(self, last_epoch=False):
r"""
Remove ground truth cache. It is recommended to call this function after evaluation.
Remove ground truth cache. It is recommended to call this function after evaluation in each epoch.
:param last_epoch: bool, whether this epoch is last epoch; if true, the directory of cache will also be removed, and no more data should be evaluated
:param last_epoch: Boolean variable, whether this epoch is last epoch; if true, the directory of cache will also be removed.
"""
if os.path.exists(self.gt_cache_path):
shutil.rmtree(self.gt_cache_path)
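A minimal evaluation sketch reflecting the interface after this revert: ``eval`` no longer accepts ``rm_gt_cache``, so the ground-truth cache is cleared with an explicit ``rm_gt_cache`` call once evaluation is finished. The solver below is a placeholder, and the exact return shape of ``rand_get_data`` and the ``classes`` attribute name are assumptions, not defined by this commit.

```python
# Sketch only: `identity_solver` stands in for a real matching solver; the
# Benchmark constructor and rand_get_data return shape are assumed from the docs.
import numpy as np
from pygmtools.benchmark import Benchmark

def identity_solver(data_list):
    """Placeholder solver: returns an identity 'matching' just to exercise the API."""
    n = len(data_list[0]['kpts'])
    return np.eye(n, dtype=int)

bm = Benchmark(name='WillowObject', sets='test', obj_resize=(256, 256))

prediction = []
for cls in bm.classes:  # dataset classes (attribute name assumed)
    data_list, perm_mat_dict, ids = bm.rand_get_data(cls=cls, num=2)
    prediction.append({'ids': ids, 'cls': cls, 'permmat': identity_solver(data_list)})

result = bm.eval(prediction, bm.classes, verbose=True)  # no rm_gt_cache keyword after this revert
bm.rm_gt_cache(last_epoch=True)                         # cache removal is now an explicit, separate call
```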
79 changes: 53 additions & 26 deletions pygmtools/dataset.py
@@ -465,7 +465,6 @@ def __init__(self, sets, obj_resize, **ds_dict):
SPLIT_OFFSET = dataset_cfg.WillowObject.SPLIT_OFFSET
TRAIN_SAME_AS_TEST = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST
RAND_OUTLIER = dataset_cfg.WillowObject.RAND_OUTLIER
URL = 'http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip'
if len(ds_dict.keys()) > 0:
if 'CLASSES' in ds_dict.keys():
CLASSES = ds_dict['CLASSES']
@@ -479,13 +478,11 @@ def __init__(self, sets, obj_resize, **ds_dict):
TRAIN_SAME_AS_TEST = ds_dict['TRAIN_SAME_AS_TEST']
if 'RAND_OUTLIER' in ds_dict.keys():
RAND_OUTLIER = ds_dict['RAND_OUTLIER']
if 'URL' in ds_dict.keys():
URL = ds_dict['URL']

self.dataset_dir = 'data/WillowObject'
if not os.path.exists(ROOT_DIR):
assert ROOT_DIR == dataset_cfg.WillowObject.ROOT_DIR, 'you should not change ROOT_DIR unless the data have been manually downloaded'
self.download(url=URL)
self.download(url='http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip')

if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)
@@ -623,23 +620,23 @@ def process(self):
if self.sets == 'train':
for x in range(len(self.mat_list)):
for name in self.mat_list[x]:
tmp = os.path.split(str(name))
tmp = str(name).split('/')
objID = tmp[-1].split('.')[0]
train_list.append(objID)
for x in range(len(mat_list_)):
for name in mat_list_[x]:
tmp = os.path.split(str(name))
tmp = str(name).split('/')
objID = tmp[-1].split('.')[0]
test_list.append(objID)
else:
for x in range(len(self.mat_list)):
for name in self.mat_list[x]:
tmp = os.path.split(str(name))
tmp = str(name).split('/')
objID = tmp[-1].split('.')[0]
test_list.append(objID)
for x in range(len(mat_list_)):
for name in mat_list_[x]:
tmp = os.path.split(str(name))
tmp = str(name).split('/')
objID = tmp[-1].split('.')[0]
train_list.append(objID)
str1 = json.dumps(train_list)
@@ -656,9 +653,9 @@ def process(self):

for x in range(len(data_list)):
for name in data_list[x]:
tmp = os.path.split(str(name))
tmp = str(name).split('/')
objID = tmp[-1].split('.')[0]
cls = os.path.split(tmp[0])[-1]
cls = tmp[3]
annotations = self.__get_anno_dict(name, cls)
data_dict[objID] = annotations

@@ -867,9 +864,9 @@ def process(self):

for x in range(len(data_list)):
for name in data_list[x]:
tmp = os.path.split(str(name))
tmp = str(name).split('/')
objID = tmp[-1].split('.')[0]
cls = os.path.split(tmp[0])[-1]
cls = tmp[3]
annotations = self.__get_anno_dict(name, cls)
ID = objID + '_' + cls
data_dict[ID] = annotations
@@ -1001,7 +998,6 @@ def __init__(self, sets, obj_resize, **ds_dict):
CLASSES = dataset_cfg.IMC_PT_SparseGM.CLASSES
ROOT_DIR_NPZ = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ
ROOT_DIR_IMG = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG
URL = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1Po9pRMWXTqKK2ABPpVmkcsOq-6K_2v-B'
if len(ds_dict.keys()) > 0:
if 'MAX_KPT_NUM' in ds_dict.keys():
MAX_KPT_NUM = ds_dict['MAX_KPT_NUM']
@@ -1011,20 +1007,17 @@ def __init__(self, sets, obj_resize, **ds_dict):
ROOT_DIR_NPZ = ds_dict['ROOT_DIR_NPZ']
if 'ROOT_DIR_IMG' in ds_dict.keys():
ROOT_DIR_IMG = ds_dict['ROOT_DIR_IMG']
if 'URL' in ds_dict.keys():
URL = ds_dict['URL']

self.dataset_dir = 'data/IMC-PT-SparseGM'
if not os.path.exists(ROOT_DIR_IMG):
assert ROOT_DIR_IMG == dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG, 'you should not change ROOT_DIR_IMG or ROOT_DIR_NPZ unless the data have been manually downloaded'
assert ROOT_DIR_NPZ == dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ, 'you should not change ROOT_DIR_IMG or ROOT_DIR_NPZ unless the data have been manually downloaded'
self.download(url=URL)
self.download(url='https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1Po9pRMWXTqKK2ABPpVmkcsOq-6K_2v-B')

if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)
self.sets = sets
self.classes = CLASSES[sets]
self.class_dict = CLASSES
self.max_kpt_num = MAX_KPT_NUM
self.suffix = 'imcpt-' + str(MAX_KPT_NUM)

@@ -1088,9 +1081,9 @@ def process(self):

if not os.path.exists(img_file):
total_cls = []
for cls in self.class_dict['train']:
for cls in dataset_cfg.IMC_PT_SparseGM.CLASSES['train']:
total_cls.append(cls)
for cls in self.class_dict['test']:
for cls in dataset_cfg.IMC_PT_SparseGM.CLASSES['test']:
total_cls.append(cls)

total_img_lists = [np.load(self.root_path_npz / cls / 'img_info.npz')['img_name'].tolist()
@@ -1163,34 +1156,35 @@ class CUB2011:
:param sets: str, problem set, ``'train'`` for training set and ``'test'`` for testing set
:param obj_resize: tuple, resized image size
:param ds_dict: settings of dataset, containing at most 1 params(key) for CUB2011:
:param ds_dict: settings of dataset, containing at most 2 params(keys) for CUB2011:
* **ROOT_DIR**: str, directory of data
* **CLS_SPLIT**: str, ``'ori'`` (original split), ``'sup'`` (super class) or ``'all'`` (all birds as one class)
"""
def __init__(self, sets, obj_resize, **ds_dict):
CLS_SPLIT = dataset_cfg.CUB2011.CLASS_SPLIT
ROOT_DIR = dataset_cfg.CUB2011.ROOT_DIR
URL = 'https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
if len(ds_dict.keys()) > 0:
if 'CLS_SPLIT' in ds_dict.keys():
CLS_SPLIT = ds_dict['CLS_SPLIT']
if 'ROOT_DIR' in ds_dict.keys():
ROOT_DIR = ds_dict['ROOT_DIR']
if 'URL' in ds_dict.keys():
URL = ds_dict['URL']

self.set_data = {'train': [], 'test': []}
self.classes = []

self._set_pairs = {}
self._set_mask = {}
self.cls_split = CLS_SPLIT
self.suffix = 'cub2011'
self.suffix = 'cub2011-' + CLS_SPLIT

self.rootpath = ROOT_DIR

self.dataset_dir = 'data/CUB_200_2011'
if not os.path.exists(ROOT_DIR):
assert ROOT_DIR == dataset_cfg.CUB2011.ROOT_DIR, 'you should not change ROOT_DIR unless the data have been manually downloaded'
self.download(url=URL)
self.download(url='https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45')

if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)
@@ -1225,12 +1219,45 @@ def __init__(self, sets, obj_resize, **ds_dict):
test_set.append(img_idx)
self.set_data['train'].append(train_set)
self.set_data['test'].append(test_set)
elif self.cls_split == 'sup':
super_classes = [v.split('_')[-1] for v in classes.values()]
self.classes = list(set(super_classes))
for cls in self.classes:
self.set_data['train'].append([])
self.set_data['test'].append([])
for class_idx in sorted(classes):
supcls_idx = self.classes.index(classes[class_idx].split('_')[-1])
train_set = []
test_set = []
for img_idx in class2img[class_idx]:
if train_split[img_idx] == '1':
train_set.append(img_idx)
else:
test_set.append(img_idx)
self.set_data['train'][supcls_idx] += train_set
self.set_data['test'][supcls_idx] += test_set
elif self.cls_split == 'all':
self.classes.append('cub2011')
self.set_data['train'].append([])
self.set_data['test'].append([])
for class_idx in sorted(classes):
train_set = []
test_set = []
for img_idx in class2img[class_idx]:
if train_split[img_idx] == '1':
train_set.append(img_idx)
else:
test_set.append(img_idx)
self.set_data['train'][0] += train_set
self.set_data['test'][0] += test_set
else:
raise ValueError('Unknown CUB2011.CLASS_SPLIT {}'.format(self.cls_split))
self.sets = sets
self.obj_resize = obj_resize

self.process()

def download(self, url=None, retries=50):
def download(self, url=None, retries=10):
r"""
Automatically download CUB2011 dataset.
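A short sketch of the dataset behaviour this file's hunks restore: the download URLs are hard-coded again (the ``URL`` override in ``ds_dict`` is gone), and CUB2011 once more supports the ``'ori'``, ``'sup'`` and ``'all'`` class splits, with the cache suffix ``'cub2011-<split>'``. The constructor call below follows the class docstring shown above and is not verified against this commit.

```python
# Sketch only: constructor call pattern assumed from the CUB2011 docstring above.
from pygmtools.dataset import CUB2011

# 'sup' groups the 200 bird classes into super classes (the last '_'-separated token
# of each class name); 'all' treats every bird as a single class named 'cub2011'.
ds = CUB2011(sets='train', obj_resize=(256, 256), CLS_SPLIT='sup')
print(ds.classes)   # e.g. ['Albatross', 'Auklet', ...] -- illustrative, order not guaranteed
print(ds.suffix)    # 'cub2011-sup'
```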
2 changes: 1 addition & 1 deletion pygmtools/dataset_config.py
@@ -35,7 +35,7 @@
# CUB2011 dataset
__C.CUB2011 = edict()
__C.CUB2011.ROOT_DIR = 'data/CUB_200_2011'
__C.CUB2011.CLASS_SPLIT = 'ori' # choose from 'ori' (original split), 'sup' (super class) or 'all' (all birds as one class), only support 'ori'
__C.CUB2011.CLASS_SPLIT = 'ori' # choose from 'ori' (original split), 'sup' (super class) or 'all' (all birds as one class)

# SWPair-71 Dataset
__C.SPair = edict()
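This one-line comment change reflects that all three ``CLASS_SPLIT`` values are usable again. A small sketch of setting the option globally before constructing the dataset; the ``dataset_cfg`` import path is assumed from how ``pygmtools/dataset.py`` references it, not shown in this commit.

```python
# Sketch only: the import path of dataset_cfg is an assumption.
from pygmtools.dataset_config import dataset_cfg
from pygmtools.dataset import CUB2011

dataset_cfg.CUB2011.CLASS_SPLIT = 'all'            # choose from 'ori', 'sup' or 'all'
ds = CUB2011(sets='test', obj_resize=(256, 256))   # picks up the global default split
```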
