Update code base for Chainer v4 #43

Open
wants to merge 8 commits into
base: master
7 changes: 3 additions & 4 deletions .gitignore
@@ -1,10 +1,9 @@
# OSX specific
.DS_Store

# dirs
data
results
ompose

# OSX specific
.DS_Store

# sftp settings
sftp-config.json
Expand Down
64 changes: 13 additions & 51 deletions README.md
@@ -1,73 +1,35 @@
# DeepPose

NOTE: This is not official implementation. Original paper is [DeepPose: Human Pose Estimation via Deep Neural Networks](http://arxiv.org/abs/1312.4659).
**NOTE: This is NOT the official implementation.**

This is an unofficial implementation of [DeepPose: Human Pose Estimation via Deep Neural Networks](http://arxiv.org/abs/1312.4659).

# Requirements

- Python 3.5.1+
- [Chainer 1.13.0+](https://github.com/pfnet/chainer)
- numpy 1.9+
- scikit-image 0.11.3+
- OpenCV 3.1.0+

I strongly recommend to use Anaconda environment. This repo may be able to be used in Python 2.7 environment, but I haven't tested.

## Installation of dependencies

```
pip install chainer
pip install numpy
pip install scikit-image
# for python3
conda install -c https://conda.binstar.org/menpo opencv3
# for python2
conda install opencv
```
- [Chainer](https://chainer.org/)>=4.2.0
- [CuPy](https://cupy.chainer.org/)>=4.2.0
- [ChainerCV](http://chainercv.readthedocs.io/en/stable/index.html)>=0.10.0
- [NumPy](http://numpy.org/)>=1.14.5
- [opencv-python](https://pypi.org/project/opencv-python/)==3.4.5.20

# Dataset preparation
# Download Datasets

```
bash datasets/download.sh
python datasets/flic_dataset.py
python datasets/lsp_dataset.py
python datasets/mpii_dataset.py
```

- [FLIC-full dataset](http://vision.grasp.upenn.edu/cgi-bin/index.php?n=VideoLearning.FLIC)
- [LSP Extended dataset](http://www.comp.leeds.ac.uk/mat4saj/lspet_dataset.zip)
- **MPII dataset**
- [Annotation](http://datasets.d2.mpi-inf.mpg.de/leonid14cvpr/mpii_human_pose_v1_u12_1.tar.gz)
- [Images](http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1.tar.gz)

## MPII Dataset

- [MPII Human Pose Dataset](http://human-pose.mpi-inf.mpg.de/#download)
- training images: 18079, test images: 6908
- test images don't have any annotations
- so we split training images into training/test joint set
- each joint set has
- training joint set: 17928, test joint set: 1991
- [FLIC-full dataset](https://bensapp.github.io/flic-dataset.html)
- [LSP Extended dataset](http://sam.johnson.io/research/lspet.html)

# Start training
# How to start training

Starting with the prepared shells is the easiest way. If you want to run `train.py` with your own settings, please check the options first by `python scripts/train.py --help` and modify one of the following shells to customize training settings.

## For FLIC Dataset

```
bash shells/train_flic.sh
```

## For LSP Dataset

```
bash shells/train_lsp.sh
```

## For MPII Dataset

```
bash shells/train_mpii.sh
python scripts/train.py -o results/$(date "+%Y-%m-%d_%H-%M-%S")
```

### GPU memory requirement
Expand Down
38 changes: 0 additions & 38 deletions datasets/download.sh

This file was deleted.

81 changes: 0 additions & 81 deletions datasets/flic_dataset.py

This file was deleted.

3 changes: 3 additions & 0 deletions deeppose/__init__.py
@@ -0,0 +1,3 @@
from deeppose import datasets
from deeppose import models
from deeppose import utils
Empty file added deeppose/datasets/__init__.py
Empty file.
79 changes: 79 additions & 0 deletions deeppose/datasets/flic_dataset.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2016 Shunta Saito

import os
import io
import zipfile

import cv2
import numpy as np
from chainercv.chainer_experimental.datasets import sliceable
from scipy.io import loadmat
from deeppose.utils import flic_utils
import threading


class FLICDataset(sliceable.GetterDataset):

    def __init__(self, split='train', dataset_zip_path='data/FLIC.zip'):
        super().__init__()
        self.dataset_zip_path = dataset_zip_path
        self.zf = zipfile.ZipFile(self.dataset_zip_path)
        self.zf_pid = os.getpid()
        self.img_paths = [fn for fn in self.zf.namelist() if fn.endswith('.jpg')]

        examples = loadmat(io.BytesIO(self.zf.read('FLIC/examples.mat')))['examples'][0]
        if split == 'train':
            self.examples = [e for e in examples if e['istrain'][0][0] == 1]
        elif split == 'test':
            self.examples = [e for e in examples if e['istest'][0][0] == 1]
        else:
            raise ValueError('\'split\' argument should be either \'train\' or \'test\'.')

        joint_names = flic_utils.flic_joint_names
        available_joints = flic_utils.flic_available_joints
        self.available_joint_ids = [joint_names.index(a) for a in available_joints]

        self.add_getter('img', self._get_image)
        self.add_getter('point', self._get_point)
        self.lock = threading.Lock()

    def __len__(self):
        return len(self.examples)

    def __getstate__(self):
        # The zip handle and the lock cannot be pickled, so drop them.
        d = self.__dict__.copy()
        d['zf'] = None
        d['lock'] = None
        return d

    def __setstate__(self, state):
        self.__dict__ = state
        # Recreate the lock under the same name that _get_image uses.
        self.lock = threading.Lock()

    def _get_image(self, i):
        """Extract image from the zipfile.

        Returns:
            img (ndarray): The shape is (C, H, W) and the channel follows RGB order (NOT BGR!).
        """
        with self.lock:
            # Reopen the archive after unpickling or in a forked worker process.
            if self.zf is None or self.zf_pid != os.getpid():
                self.zf_pid = os.getpid()
                self.zf = zipfile.ZipFile(self.dataset_zip_path)
            image_data = self.zf.read('FLIC/images/{}'.format(self.examples[i][3][0]))

        image_file = np.frombuffer(image_data, np.uint8)
        img = cv2.imdecode(image_file, cv2.IMREAD_COLOR)
        assert len(img.shape) == 3 and img.shape[2] == 3, "The image has wrong shape: {}".format(img.shape)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.asarray(img, dtype=np.float32)
        img = img.transpose((2, 0, 1))

        return img

    def _get_point(self, i):
        point = self.examples[i][2].T[self.available_joint_ids].astype(np.float32)
        return point[:, ::-1]  # (x, y) -> (y, x)
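
The handle management in `FLICDataset` can be looked at in isolation. The following is a minimal sketch using only the standard library, not the code in this PR: `ZipReader` and its `read` method are hypothetical names introduced for illustration. It shows the same idea, namely dropping the zip handle and the lock when the object is serialized, then reopening the archive lazily when the handle is missing or belongs to another process.

```python
import os
import threading
import zipfile


class ZipReader:
    """Hypothetical helper: reads zip members and survives pickling/forking."""

    def __init__(self, zip_path):
        self.zip_path = zip_path
        self.zf = zipfile.ZipFile(zip_path)
        self.zf_pid = os.getpid()
        self.lock = threading.Lock()

    def __getstate__(self):
        # Open file handles and locks cannot be pickled, so drop them.
        state = self.__dict__.copy()
        state['zf'] = None
        state['lock'] = None
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        # Recreate the lock; the archive itself is reopened lazily in read().
        self.lock = threading.Lock()

    def read(self, name):
        with self.lock:
            # Reopen if the handle was dropped or was opened by another process.
            if self.zf is None or self.zf_pid != os.getpid():
                self.zf_pid = os.getpid()
                self.zf = zipfile.ZipFile(self.zip_path)
            return self.zf.read(name)
```

The attribute recreated in `__setstate__` has to carry the same name as the one used in `read`; otherwise a restored copy would hit `with None:` on its first read.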

File renamed without changes.
File renamed without changes.
8 changes: 2 additions & 6 deletions scripts/loss.py → deeppose/functions/l2_loss.py
Expand Up @@ -2,18 +2,14 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2016 Shunta Saito

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from chainer import reporter

import chainer


class MeanSquaredError(chainer.Function):
class L2Loss(chainer.FunctionNode):

"""Mean squared error (a.k.a. Euclidean loss) function.
"""L2 loss function.

In forward method, it calculates mean squared error between two variables
with ignoring all elements that the value of ignore_joints at the same
Expand Down
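
To make the masking idea in the docstring concrete, here is a rough NumPy sketch rather than the `chainer.FunctionNode` from this PR. The docstring is cut off in this view, so the mask convention is an assumption: `ignore_mask` is a hypothetical boolean array where `True` marks an element to leave out of the loss. The function name is also made up for illustration, and no backward pass is implemented.

```python
import numpy as np


def masked_mean_squared_error(pred, target, ignore_mask):
    """Mean squared error over the elements where ignore_mask is False.

    Assumption: ignore_mask is True for elements to skip. The real
    ignore_joints convention in the PR may differ.
    """
    pred = np.asarray(pred, dtype=np.float32)
    target = np.asarray(target, dtype=np.float32)
    keep = ~np.asarray(ignore_mask, dtype=bool)

    n_kept = int(keep.sum())
    if n_kept == 0:
        # Nothing to compare; return zero instead of dividing by zero.
        return 0.0

    diff = (pred - target)[keep]
    return float((diff ** 2).sum() / n_kept)
```

Dividing by the number of kept elements, not the total, keeps the loss scale comparable between samples with different numbers of annotated joints.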
47 changes: 47 additions & 0 deletions deeppose/models/AlexNet.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2016 Shunta Saito

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import chainer
import chainer.functions as F
import chainer.links as L


class AlexNet(chainer.Chain):

    """Single-GPU AlexNet without partition toward the channel axis."""

    insize = 227

    def __init__(self):
        # The class is named AlexNet, so super() must reference AlexNet.
        super(AlexNet, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(None, 96, 11, stride=4)
            self.conv2 = L.Convolution2D(None, 256, 5, pad=2)
            self.conv3 = L.Convolution2D(None, 384, 3, pad=1)
            self.conv4 = L.Convolution2D(None, 384, 3, pad=1)
            self.conv5 = L.Convolution2D(None, 256, 3, pad=1)
            self.fc6 = L.Linear(None, 4096)
            self.fc7 = L.Linear(None, 4096)
            self.fc8 = L.Linear(None, 1000)

    def forward(self, x, t):
        h = F.max_pooling_2d(F.local_response_normalization(
            F.relu(self.conv1(x))), 3, stride=2)
        h = F.max_pooling_2d(F.local_response_normalization(
            F.relu(self.conv2(h))), 3, stride=2)
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2)
        h = F.dropout(F.relu(self.fc6(h)))
        h = F.dropout(F.relu(self.fc7(h)))
        h = self.fc8(h)

        loss = F.softmax_cross_entropy(h, t)
        chainer.report({'loss': loss, 'accuracy': F.accuracy(h, t)}, self)
        return loss
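
Because the layers are declared with `None` as the input size, the shapes are only resolved at the first forward pass. As a sanity check, the spatial size reaching `fc6` can be worked out by hand. The sketch below is plain Python with hypothetical helper names (`conv_out`, `alexnet_feature_size`) and uses the floor-division output-size formula. As far as I know, Chainer's `max_pooling_2d` rounds up by default (`cover_all=True`), but for `insize = 227` every pooling input divides evenly, so both conventions give the same result.

```python
def conv_out(size, ksize, stride=1, pad=0):
    """Output length of a convolution or pooling window (floor convention)."""
    return (size + 2 * pad - ksize) // stride + 1


def alexnet_feature_size(insize=227):
    """Spatial size of the feature map that reaches fc6."""
    s = conv_out(insize, 11, stride=4)  # conv1
    s = conv_out(s, 3, stride=2)        # max pooling after conv1
    s = conv_out(s, 5, pad=2)           # conv2
    s = conv_out(s, 3, stride=2)        # max pooling after conv2
    s = conv_out(s, 3, pad=1)           # conv3
    s = conv_out(s, 3, pad=1)           # conv4
    s = conv_out(s, 3, pad=1)           # conv5
    s = conv_out(s, 3, stride=2)        # max pooling after conv5
    return s


side = alexnet_feature_size(227)
fc6_inputs = 256 * side * side  # conv5 has 256 output channels
```

With these settings `side` is 6, so `fc6` receives 256 * 6 * 6 = 9216 inputs per sample.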
1 change: 1 addition & 0 deletions deeppose/models/__init__.py
@@ -0,0 +1 @@
from deeppose.models.AlexNet import AlexNet