-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
44 changed files
with
5,547 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Project Specific | ||
models/*.weights | ||
models/*.pth | ||
data/* | ||
data/*.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2020 Torben Teepe | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#!/bin/bash | ||
# Download weights for vanilla YOLOv3 | ||
wget -c https://pjreddie.com/media/files/yolov3.weights | ||
# Download weights for tiny YOLOv3 | ||
wget -c https://pjreddie.com/media/files/yolov3-tiny.weights | ||
## Download weights for backbone network | ||
#wget -c https://pjreddie.com/media/files/darknet53.conv.74 | ||
|
||
print "#############################################################" | ||
print "######## Weights for HRNet Pose Estimation need to ##########" | ||
print "######## be downloaded manually from here: ##########" | ||
print "######## https://drive.google.com/drive/folders/1nzM_OBV9LbAEA7HClC0chEyf_7ECDXYA" | ||
print "######## Files: pose_hrnet_*.pth ##########" | ||
print "#############################################################" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
yacs==0.1.8 | ||
numpy==1.19.5 | ||
torch==1.7.1 | ||
torchvision==0.8.2 | ||
matplotlib==3.3.3 | ||
tabulate==0.8.7 | ||
tensorflow==2.4.0 | ||
tensorboard==2.4.0 | ||
pillow==8.1.0 | ||
tqdm==4.56.0 | ||
opencv-python~=4.5 | ||
jupyter==1.0.0 | ||
pandas==1.1.0 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import os | ||
import argparse | ||
import torch | ||
from models.st_gcn.st_gcn import STGCNEmbedding | ||
import models.ResGCNv1 | ||
|
||
|
||
def parse_option(): | ||
parser = argparse.ArgumentParser(description="Training model on gait sequence") | ||
parser.add_argument("dataset", choices=["casia-b", "outdoor-gait", "tum-gaid"]) | ||
parser.add_argument("train_data_path", help="Path to train data CSV") | ||
parser.add_argument("--valid_data_path", help="Path to validation data CSV") | ||
parser.add_argument("--valid_split", type=float, default=0.2) | ||
|
||
parser.add_argument("--checkpoint_path", help="Path to checkpoint to resume") | ||
parser.add_argument("--weight_path", help="Path to weights for model") | ||
|
||
# Optionals | ||
parser.add_argument("--num_workers", type=int, default=8) | ||
parser.add_argument( | ||
"--gpus", default="0", help="-1 for CPU, use comma for multiple gpus" | ||
) | ||
parser.add_argument("--batch_size", type=int, default=64) | ||
parser.add_argument("--batch_size_validation", type=int, default=64) | ||
parser.add_argument("--epochs", type=int, default=500) | ||
parser.add_argument("--start_epoch", type=int, default=1) | ||
parser.add_argument("--log_interval", type=int, default=10) | ||
parser.add_argument("--save_interval", type=int, default=50, help="save frequency") | ||
parser.add_argument( | ||
"--save_best_start", type=float, default=0.3, help="save frequency" | ||
) | ||
parser.add_argument("--use_amp", action="store_true") | ||
parser.add_argument("--tune", action="store_true") | ||
parser.add_argument("--shuffle", action="store_true") | ||
parser.add_argument("--exp_name", help="Name of the experiment") | ||
|
||
parser.add_argument("--network_name", default="resgcn-n39-r4") | ||
parser.add_argument("--sequence_length", type=int, default=60) | ||
parser.add_argument("--embedding_layer_size", type=int, default=256) | ||
parser.add_argument("--temporal_kernel_size", type=int, default=9) | ||
parser.add_argument("--dropout", type=float, default=0.4) | ||
parser.add_argument("--learning_rate", type=float, default=1e-3) | ||
parser.add_argument( | ||
"--lr_decay_rate", type=float, default=0.1, help="decay rate for learning rate" | ||
) | ||
parser.add_argument("--point_noise_std", type=float, default=0.05) | ||
parser.add_argument("--joint_noise_std", type=float, default=0.1) | ||
parser.add_argument("--flip_probability", type=float, default=0.5) | ||
parser.add_argument("--mirror_probability", type=float, default=0.5) | ||
parser.add_argument("--weight_decay", type=float, default=1e-5) | ||
parser.add_argument("--use_multi_branch", action="store_true") | ||
parser.add_argument( | ||
"--temp", type=float, default=0.07, help="temperature for loss function" | ||
) | ||
opt = parser.parse_args() | ||
|
||
# Sanitize opts | ||
opt.gpus_str = opt.gpus | ||
opt.gpus = [int(gpu) for gpu in opt.gpus.split(",")] | ||
|
||
return opt | ||
|
||
|
||
def log_hyperparameter(writer, opt, accuracy, loss): | ||
writer.add_hparams( | ||
{ | ||
"batch_size": opt.batch_size, | ||
"sequence_length": opt.sequence_length, | ||
"embedding_layer_size": opt.embedding_layer_size, | ||
"dropout": opt.dropout, | ||
"learning_rate": opt.learning_rate, | ||
"lr_decay_rate": opt.lr_decay_rate, | ||
"point_noise_std": opt.point_noise_std, | ||
"weight_decay": opt.weight_decay, | ||
"temp": opt.temp, | ||
}, | ||
{ | ||
"hparam/accuracy": accuracy, | ||
"hparam/loss": loss, | ||
}, | ||
) | ||
|
||
|
||
def setup_environment(opt): | ||
# HACK: Fix tensorboard | ||
import tensorflow as tf | ||
import tensorboard as tb | ||
|
||
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus_str | ||
opt.cuda = opt.gpus[0] >= 0 | ||
torch.device("cuda" if opt.cuda else "cpu") | ||
|
||
return opt | ||
|
||
|
||
def get_model_stgcn(opt): | ||
# Model | ||
input_channels = 3 | ||
edge_importance_weighting = True | ||
graph_args = {"strategy": "spatial"} | ||
|
||
embedding_net = STGCNEmbedding( | ||
input_channels, | ||
graph_args, | ||
edge_importance_weighting=edge_importance_weighting, | ||
embedding_layer_size=opt.embedding_layer_size, | ||
temporal_kernel_size=opt.temporal_kernel_size, | ||
dropout=opt.dropout, | ||
) | ||
|
||
return embedding_net | ||
|
||
|
||
def get_model_resgcn(graph, opt): | ||
model_args = { | ||
"A": torch.tensor(graph.A, dtype=torch.float32, requires_grad=False), | ||
"num_class": opt.embedding_layer_size, | ||
"num_input": 1 if not opt.use_multi_branch else 3, | ||
"num_channel": 3 if not opt.use_multi_branch else 6, | ||
"parts": graph.parts, | ||
} | ||
return models.ResGCNv1.create(opt.network_name, **model_args) | ||
|
||
|
||
def get_trainer(model, opt, steps_per_epoch): | ||
optimizer = torch.optim.Adam( | ||
model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay | ||
) | ||
scheduler = torch.optim.lr_scheduler.OneCycleLR( | ||
optimizer, opt.learning_rate, epochs=opt.epochs, steps_per_epoch=steps_per_epoch | ||
) | ||
scaler = torch.cuda.amp.GradScaler(enabled=opt.use_amp) | ||
|
||
return optimizer, scheduler, scaler | ||
|
||
|
||
def load_checkpoint(model, optimizer, scheduler, scaler, opt): | ||
if opt.checkpoint_path is not None: | ||
checkpoint = torch.load(opt.checkpoint_path) | ||
model.load_state_dict(checkpoint["model"]) | ||
optimizer.load_state_dict(checkpoint["optimizer"]) | ||
scheduler.load_state_dict(checkpoint["scheduler"]) | ||
scaler.load_state_dict(checkpoint["scaler"]) | ||
opt.start_epoch = checkpoint["epoch"] | ||
|
||
if opt.weight_path is not None: | ||
checkpoint = torch.load(opt.weight_path) | ||
model.load_state_dict(checkpoint["model"], strict=False) | ||
|
||
|
||
def save_model(model, optimizer, scheduler, scaler, opt, epoch, save_file): | ||
print("==> Saving...") | ||
state = { | ||
"opt": opt, | ||
"model": model.state_dict(), | ||
"optimizer": optimizer.state_dict(), | ||
"scheduler": scheduler.state_dict(), | ||
"scaler": scaler.state_dict(), | ||
"epoch": epoch, | ||
} | ||
torch.save(state, save_file) | ||
del state | ||
|
||
|
||
def count_parameters(model): | ||
""" | ||
Useful function to compute number of parameters in a model. | ||
""" | ||
return sum(p.numel() for p in model.parameters() if p.requires_grad) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .preparation import DatasetSimple, DatasetDetections | ||
from .gait import ( | ||
CasiaBPose, | ||
) | ||
|
||
|
||
def dataset_factory(name): | ||
if name == "casia-b": | ||
return CasiaBPose | ||
|
||
raise ValueError() |
Oops, something went wrong.