diff --git a/README.md b/README.md
index 66256bf..59b8bb0 100644
--- a/README.md
+++ b/README.md
@@ -13,13 +13,20 @@ In this repo, we show the example of model on NTU-RGB+D dataset.
 * pyyaml
 * argparse
 * numpy
+* torch 1.7.1
 
 # Environments
 We use a similar input/output interface and system configuration to ST-GCN, so the torchlight module must be set up first.
+```
+cd torchlight
+cp torchlight/torchlight/__init__.py gpu.py io.py ../
+```
+Then change all "from torchlight import ..." statements to
+"from torchlight.io import ..."
 Run
 ```
-cd torchlight, python setup.py, cd ..
+cd torchlight && python setup.py install && cd ..
 ```
@@ -30,22 +37,24 @@ For NTU-RGB+D dataset, you can download it from [NTU-RGB+D](http://rose1.ntu.edu
 ```
 Then, run the preprocessing program to generate the input data; training cannot start without this step.
 ```
-python ./data_gen/ntu_gen_preprocess.py
+cd data_gen
+python ntu_gen_preprocess.py
 ```
 
 # Training and Testing
 With this repo, you first pretrain the AIM and save the module; then you train the main pipeline of AS-GCN. For the recommended benchmark of Cross-Subject in NTU-RGB+D,
 ```
-PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml
-TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml
+PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2
+TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml --device 0 --batch_size 4
+# TrainMainPipeline only runs on a single GPU; with multiple devices it fails with "Caught RuntimeError in replica 0 on device 0".
 Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml
 ```
 For Cross-View,
 ```
-PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml
-TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml
-Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml
+PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xview/train_aim.yaml
+TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xview/train.yaml
+Test: python main.py recognition -c config/as_gcn/ntu-xview/test.yaml
 ```
 
 # Acknowledgement
diff --git a/asgcn_3090_cuda11_1_environment.yml b/asgcn_3090_cuda11_1_environment.yml
new file mode 100644
index 0000000..5c5699e
--- /dev/null
+++ b/asgcn_3090_cuda11_1_environment.yml
@@ -0,0 +1,85 @@
+name: asgcn
+channels:
+  - pytorch
+  - https://mirrors.ustc.edu.cn/anaconda/pkgs/main
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - blas=1.0=mkl
+  - ca-certificates=2021.4.13=h06a4308_1
+  - certifi=2020.12.5=py36h06a4308_0
+  - cffi=1.14.5=py36h261ae71_0
+  - cuda90=1.0=h6433d27_0
+  - cudatoolkit=10.0.130=0
+  - cudnn=7.6.5=cuda10.0_0
+  - cycler=0.10.0=py36_0
+  - dbus=1.13.18=hb2f20db_0
+  - expat=2.3.0=h2531618_2
+  - fontconfig=2.13.1=h6c09931_0
+  - freetype=2.10.4=h5ab3b9f_0
+  - glib=2.68.1=h36276a3_0
+  - gst-plugins-base=1.14.0=h8213a91_2
+  - gstreamer=1.14.0=h28cd5cc_2
+  - icu=58.2=he6710b0_3
+  - intel-openmp=2019.4=243
+  - jpeg=9b=h024ee3a_2
+  - kiwisolver=1.3.1=py36h2531618_0
+  - lcms2=2.11=h396b838_0
+  - ld_impl_linux-64=2.33.1=h53a641e_7
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=9.1.0=hdf63c60_0
+  - libgfortran-ng=7.3.0=hdf63c60_0
+  - libpng=1.6.37=hbc83047_0
+  - libstdcxx-ng=9.1.0=hdf63c60_0
+  - libtiff=4.2.0=h3942068_0
+  - libuuid=1.0.3=h1bed415_2
+  - libwebp-base=1.2.0=h27cfd23_0
+  - libxcb=1.14=h7b6447c_0
+  - libxml2=2.9.10=hb55368b_3
+  - lz4-c=1.9.3=h2531618_0
+  - matplotlib=3.3.2=h06a4308_0
+  - matplotlib-base=3.3.2=py36h817c723_0
+ 
- mkl=2018.0.3=1 + - mkl_fft=1.0.6=py36h7dd41cf_0 + - mkl_random=1.0.1=py36h4414c95_1 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=py36hff7bd54_0 + - olefile=0.46=py36_0 + - openssl=1.1.1k=h27cfd23_0 + - pcre=8.44=he6710b0_0 + - pillow=8.1.2=py36he98fc37_0 + - pip=21.0.1=py36h06a4308_0 + - pycparser=2.20=py_2 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py36h05f1152_2 + - python=3.6.13=hdb3f193_0 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py36h06a4308_0 + - sip=4.19.8=py36hf484d3e_0 + - six=1.15.0=py36h06a4308_0 + - sqlite=3.35.1=hdfb4753_0 + - tbb=2021.2.0=hff7bd54_0 + - tbb4py=2021.2.0=py36hff7bd54_0 + - tk=8.6.10=hbc83047_0 + - tornado=6.1=py36h27cfd23_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - pip: + - argparse==1.4.0 + - cached-property==1.5.2 + - dataclasses==0.8 + - h5py==3.1.0 + - imageio==2.9.0 + - numpy==1.19.5 + - opencv-python==4.5.1.48 + - pyyaml==6.0 + - scikit-video==1.1.11 + - scipy==1.5.4 + - torch==1.8.1+cu111 + - torchvision==0.9.1+cu111 + - tqdm==4.60.0 + - typing-extensions==3.7.4.3 diff --git a/data/NTU-RGB+D/samples_with_missing_skeletons.txt b/data/NTU-RGB+D/samples_with_missing_skeletons.txt deleted file mode 100644 index 5ad472e..0000000 --- a/data/NTU-RGB+D/samples_with_missing_skeletons.txt +++ /dev/null @@ -1,302 +0,0 @@ -S001C002P005R002A008 -S001C002P006R001A008 -S001C003P002R001A055 -S001C003P002R002A012 -S001C003P005R002A004 -S001C003P005R002A005 -S001C003P005R002A006 -S001C003P006R002A008 -S002C002P011R002A030 -S002C003P008R001A020 -S002C003P010R002A010 -S002C003P011R002A007 -S002C003P011R002A011 -S002C003P014R002A007 -S003C001P019R001A055 -S003C002P002R002A055 -S003C002P018R002A055 -S003C003P002R001A055 -S003C003P016R001A055 -S003C003P018R002A024 -S004C002P003R001A013 -S004C002P008R001A009 -S004C002P020R001A003 -S004C002P020R001A004 -S004C002P020R001A012 -S004C002P020R001A020 -S004C002P020R001A021 -S004C002P020R001A036 -S005C002P004R001A001 -S005C002P004R001A003 -S005C002P010R001A016 -S005C002P010R001A017 -S005C002P010R001A048 -S005C002P010R001A049 -S005C002P016R001A009 -S005C002P016R001A010 -S005C002P018R001A003 -S005C002P018R001A028 -S005C002P018R001A029 -S005C003P016R002A009 -S005C003P018R002A013 -S005C003P021R002A057 -S006C001P001R002A055 -S006C002P007R001A005 -S006C002P007R001A006 -S006C002P016R001A043 -S006C002P016R001A051 -S006C002P016R001A052 -S006C002P022R001A012 -S006C002P023R001A020 -S006C002P023R001A021 -S006C002P023R001A022 -S006C002P023R001A023 -S006C002P024R001A018 -S006C002P024R001A019 -S006C003P001R002A013 -S006C003P007R002A009 -S006C003P007R002A010 -S006C003P007R002A025 -S006C003P016R001A060 -S006C003P017R001A055 -S006C003P017R002A013 -S006C003P017R002A014 -S006C003P017R002A015 -S006C003P022R002A013 -S007C001P018R002A050 -S007C001P025R002A051 -S007C001P028R001A050 -S007C001P028R001A051 -S007C001P028R001A052 -S007C002P008R002A008 -S007C002P015R002A055 -S007C002P026R001A008 -S007C002P026R001A009 -S007C002P026R001A010 -S007C002P026R001A011 -S007C002P026R001A012 -S007C002P026R001A050 -S007C002P027R001A011 -S007C002P027R001A013 -S007C002P028R002A055 -S007C003P007R001A002 -S007C003P007R001A004 -S007C003P019R001A060 -S007C003P027R002A001 -S007C003P027R002A002 -S007C003P027R002A003 -S007C003P027R002A004 -S007C003P027R002A005 -S007C003P027R002A006 -S007C003P027R002A007 -S007C003P027R002A008 -S007C003P027R002A009 -S007C003P027R002A010 -S007C003P027R002A011 -S007C003P027R002A012 -S007C003P027R002A013 
-S008C002P001R001A009 -S008C002P001R001A010 -S008C002P001R001A014 -S008C002P001R001A015 -S008C002P001R001A016 -S008C002P001R001A018 -S008C002P001R001A019 -S008C002P008R002A059 -S008C002P025R001A060 -S008C002P029R001A004 -S008C002P031R001A005 -S008C002P031R001A006 -S008C002P032R001A018 -S008C002P034R001A018 -S008C002P034R001A019 -S008C002P035R001A059 -S008C002P035R002A002 -S008C002P035R002A005 -S008C003P007R001A009 -S008C003P007R001A016 -S008C003P007R001A017 -S008C003P007R001A018 -S008C003P007R001A019 -S008C003P007R001A020 -S008C003P007R001A021 -S008C003P007R001A022 -S008C003P007R001A023 -S008C003P007R001A025 -S008C003P007R001A026 -S008C003P007R001A028 -S008C003P007R001A029 -S008C003P007R002A003 -S008C003P008R002A050 -S008C003P025R002A002 -S008C003P025R002A011 -S008C003P025R002A012 -S008C003P025R002A016 -S008C003P025R002A020 -S008C003P025R002A022 -S008C003P025R002A023 -S008C003P025R002A030 -S008C003P025R002A031 -S008C003P025R002A032 -S008C003P025R002A033 -S008C003P025R002A049 -S008C003P025R002A060 -S008C003P031R001A001 -S008C003P031R002A004 -S008C003P031R002A014 -S008C003P031R002A015 -S008C003P031R002A016 -S008C003P031R002A017 -S008C003P032R002A013 -S008C003P033R002A001 -S008C003P033R002A011 -S008C003P033R002A012 -S008C003P034R002A001 -S008C003P034R002A012 -S008C003P034R002A022 -S008C003P034R002A023 -S008C003P034R002A024 -S008C003P034R002A044 -S008C003P034R002A045 -S008C003P035R002A016 -S008C003P035R002A017 -S008C003P035R002A018 -S008C003P035R002A019 -S008C003P035R002A020 -S008C003P035R002A021 -S009C002P007R001A001 -S009C002P007R001A003 -S009C002P007R001A014 -S009C002P008R001A014 -S009C002P015R002A050 -S009C002P016R001A002 -S009C002P017R001A028 -S009C002P017R001A029 -S009C003P017R002A030 -S009C003P025R002A054 -S010C001P007R002A020 -S010C002P016R002A055 -S010C002P017R001A005 -S010C002P017R001A018 -S010C002P017R001A019 -S010C002P019R001A001 -S010C002P025R001A012 -S010C003P007R002A043 -S010C003P008R002A003 -S010C003P016R001A055 -S010C003P017R002A055 -S011C001P002R001A008 -S011C001P018R002A050 -S011C002P008R002A059 -S011C002P016R002A055 -S011C002P017R001A020 -S011C002P017R001A021 -S011C002P018R002A055 -S011C002P027R001A009 -S011C002P027R001A010 -S011C002P027R001A037 -S011C003P001R001A055 -S011C003P002R001A055 -S011C003P008R002A012 -S011C003P015R001A055 -S011C003P016R001A055 -S011C003P019R001A055 -S011C003P025R001A055 -S011C003P028R002A055 -S012C001P019R001A060 -S012C001P019R002A060 -S012C002P015R001A055 -S012C002P017R002A012 -S012C002P025R001A060 -S012C003P008R001A057 -S012C003P015R001A055 -S012C003P015R002A055 -S012C003P016R001A055 -S012C003P017R002A055 -S012C003P018R001A055 -S012C003P018R001A057 -S012C003P019R002A011 -S012C003P019R002A012 -S012C003P025R001A055 -S012C003P027R001A055 -S012C003P027R002A009 -S012C003P028R001A035 -S012C003P028R002A055 -S013C001P015R001A054 -S013C001P017R002A054 -S013C001P018R001A016 -S013C001P028R001A040 -S013C002P015R001A054 -S013C002P017R002A054 -S013C002P028R001A040 -S013C003P008R002A059 -S013C003P015R001A054 -S013C003P017R002A054 -S013C003P025R002A022 -S013C003P027R001A055 -S013C003P028R001A040 -S014C001P027R002A040 -S014C002P015R001A003 -S014C002P019R001A029 -S014C002P025R002A059 -S014C002P027R002A040 -S014C002P039R001A050 -S014C003P007R002A059 -S014C003P015R002A055 -S014C003P019R002A055 -S014C003P025R001A048 -S014C003P027R002A040 -S015C001P008R002A040 -S015C001P016R001A055 -S015C001P017R001A055 -S015C001P017R002A055 -S015C002P007R001A059 -S015C002P008R001A003 -S015C002P008R001A004 -S015C002P008R002A040 -S015C002P015R001A002 -S015C002P016R001A001 
-S015C002P016R002A055 -S015C003P008R002A007 -S015C003P008R002A011 -S015C003P008R002A012 -S015C003P008R002A028 -S015C003P008R002A040 -S015C003P025R002A012 -S015C003P025R002A017 -S015C003P025R002A020 -S015C003P025R002A021 -S015C003P025R002A030 -S015C003P025R002A033 -S015C003P025R002A034 -S015C003P025R002A036 -S015C003P025R002A037 -S015C003P025R002A044 -S016C001P019R002A040 -S016C001P025R001A011 -S016C001P025R001A012 -S016C001P025R001A060 -S016C001P040R001A055 -S016C001P040R002A055 -S016C002P008R001A011 -S016C002P019R002A040 -S016C002P025R002A012 -S016C003P008R001A011 -S016C003P008R002A002 -S016C003P008R002A003 -S016C003P008R002A004 -S016C003P008R002A006 -S016C003P008R002A009 -S016C003P019R002A040 -S016C003P039R002A016 -S017C001P016R002A031 -S017C002P007R001A013 -S017C002P008R001A009 -S017C002P015R001A042 -S017C002P016R002A031 -S017C002P016R002A055 -S017C003P007R002A013 -S017C003P008R001A059 -S017C003P016R002A031 -S017C003P017R001A055 -S017C003P020R001A059 diff --git a/data/readme.md b/data/readme.md deleted file mode 100644 index 777d39a..0000000 --- a/data/readme.md +++ /dev/null @@ -1 +0,0 @@ -The filepath of data (NTU-RGB+D) diff --git a/data_gen/__pycache__/__init__.cpython-36.pyc b/data_gen/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..600c7f3 Binary files /dev/null and b/data_gen/__pycache__/__init__.cpython-36.pyc differ diff --git a/data_gen/__pycache__/preprocess.cpython-36.pyc b/data_gen/__pycache__/preprocess.cpython-36.pyc new file mode 100644 index 0000000..c42e997 Binary files /dev/null and b/data_gen/__pycache__/preprocess.cpython-36.pyc differ diff --git a/data_gen/__pycache__/rotation.cpython-36.pyc b/data_gen/__pycache__/rotation.cpython-36.pyc new file mode 100644 index 0000000..722593f Binary files /dev/null and b/data_gen/__pycache__/rotation.cpython-36.pyc differ diff --git a/data_gen/gpu.py b/data_gen/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/data_gen/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. 
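+    (it does this by allocating a one-element tensor on each requested GPU)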
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/data_gen/io.py b/data_gen/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/data_gen/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + 
os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/data_gen/ntu_gen_preprocess.py b/data_gen/ntu_gen_preprocess.py index 6323b30..9bc8423 100644 --- a/data_gen/ntu_gen_preprocess.py +++ b/data_gen/ntu_gen_preprocess.py @@ -140,4 +140,5 @@ def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xsub', set if not os.path.exists(out_path): os.makedirs(out_path) print(b, sn) - gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) \ No newline at end of file + #gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) + gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, set_name=sn) diff --git a/feeder/feeder.py b/feeder/feeder.py index dba96f3..7998a14 100644 --- a/feeder/feeder.py +++ b/feeder/feeder.py @@ -51,7 +51,7 @@ def load_data(self, mmap): self.data = self.data[0:100] self.sample_name = self.sample_name[0:100] - self.N, self.C, self.T, self.V, self.M = self.data.shape + self.N, self.C, self.T, self.V, self.M = self.data.shape # (40091, 3, 300, 25, 2) def __len__(self): return 
len(self.label) diff --git a/log/data_tree.log b/log/data_tree.log new file mode 100644 index 0000000..3bfbe88 --- /dev/null +++ b/log/data_tree.log @@ -0,0 +1,16 @@ +data +|-- NTU-RGB+D +| `-- samples_with_missing_skeletons.txt +`-- nturgb_d + |-- xsub + | |-- train_data_joint_pad.npy + | |-- train_label.pkl + | |-- val_data_joint_pad.npy + | `-- val_label.pkl + `-- xview + |-- train_data_joint_pad.npy + |-- train_label.pkl + |-- val_data_joint_pad.npy + `-- val_label.pkl + +4 directories, 9 files diff --git a/log/train_aim.log b/log/train_aim.log new file mode 100644 index 0000000..1f84ef7 --- /dev/null +++ b/log/train_aim.log @@ -0,0 +1,21 @@ +$python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 + +/root/AS-GCN/processor/io.py:34: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details. + default_arg = yaml.load(f) +/root/AS-GCN/net/utils/adj_learn.py:18: UserWarning: This overload of nonzero is deprecated: + nonzero() +Consider using one of the following signatures instead: + nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.) + offdiag_indices = (ones - eye).nonzero().t() +/root/anaconda3/envs/stgcn/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported. + warnings.warn("Setting attributes on ParameterList is not supported.") +/root/AS-GCN/net/utils/adj_learn.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. + soft_max_1d = F.softmax(trans_input) + +[05.08.21|19:25:22] Parameters: +{'work_dir': './work_dir/recognition/ntu-xsub/AS_GCN', 'config': 'config/as_gcn/ntu-xsub/train_aim.yaml', 'phase': 'train', 'save_result': False, 'start_epoch': 0, 'num_epoch': 10, 'use_gpu': True, 'device': [0, 1, 2], 'log_interval': 100, 'save_interval': 1, 'eval_interval': 5, 'save_log': True, 'print_log': True, 'pavi_log': False, 'feeder': 'feeder.feeder.Feeder', 'num_worker': 4, 'train_feeder_args': {'data_path': './data/nturgb_d/xsub/train_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/train_label.pkl', 'random_move': True, 'repeat_pad': True, 'down_sample': True, 'debug': False}, 'test_feeder_args': {'data_path': './data/nturgb_d/xsub/val_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/val_label.pkl', 'random_move': False, 'repeat_pad': True, 'down_sample': True}, 'batch_size': 32, 'test_batch_size': 32, 'debug': False, 'model1': 'net.as_gcn.Model', 'model2': 'net.utils.adj_learn.AdjacencyLearn', 'model1_args': {'in_channels': 3, 'num_class': 60, 'dropout': 0.5, 'edge_importance_weighting': True, 'graph_args': {'layout': 'ntu-rgb+d', 'strategy': 'spatial', 'max_hop': 4}}, 'model2_args': {'n_in_enc': 150, 'n_hid_enc': 128, 'edge_types': 3, 'n_in_dec': 3, 'n_hid_dec': 128, 'node_num': 25}, 'weights1': None, 'weights2': None, 'ignore_weights': [], 'show_topk': [1, 5], 'base_lr1': 0.1, 'base_lr2': 0.0005, 'step': [50, 70, 90], 'optimizer': 'SGD', 'nesterov': True, 'weight_decay': 0.0001, 'max_hop_dir': 'max_hop_4', 'lamda_act': 0.5, 'lamda_act_dir': 'lamda_05'} + +[05.08.21|19:25:22] Training epoch: 0 +[05.08.21|19:25:29] Iter 0 Done. | loss2: 876.8732 | loss_nll: 832.9621 | loss_kl: 43.9111 | lr: 0.000500 +[05.08.21|19:26:56] Iter 100 Done. | loss2: 118.9051 | loss_nll: 110.3876 | loss_kl: 8.5176 | lr: 0.000500 +[05.08.21|19:28:14] Iter 200 Done. 
| loss2: 76.1775 | loss_nll: 71.7404 | loss_kl: 4.4371 | lr: 0.000500
diff --git a/main.py b/main.py
index ee1a0a2..bd402cb 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,7 @@
 import argparse
 import sys
 import torchlight
-from torchlight import import_class
+from torchlight.io import import_class
 
 
 if __name__ == '__main__':
@@ -10,7 +10,7 @@
 
     processors = dict()
     processors['recognition'] = import_class('processor.recognition.REC_Processor')
-    processors['demo'] = import_class('processor.demo.Demo')
+    #processors['demo'] = import_class('processor.demo.Demo')
 
     subparsers = parser.add_subparsers(dest='processor')
     for k, p in processors.items():
diff --git a/net/as_gcn.py b/net/as_gcn.py
index 7468be4..0be1829 100644
--- a/net/as_gcn.py
+++ b/net/as_gcn.py
@@ -50,18 +50,18 @@ def __init__(self, in_channels, num_class, graph_args,
         self.fcn = nn.Conv2d(256, num_class, kernel_size=1)
 
     def forward(self, x, x_target, x_last, A_act, lamda_act):
         N, C, T, V, M = x.size()
-        x_recon = x[:,:,:,:,0]                             # [2N, 3, 300, 25]
-        x = x.permute(0, 4, 3, 1, 2).contiguous()          # [N, 2, 25, 3, 300]
-        x = x.view(N * M, V * C, T)                        # [2N, 75, 300]
+        x_recon = x[:,:,:,:,0]                             # [2N, 3, 300, 25] wsx: x_recon(4,3,290,25) selects the first person's data
+        x = x.permute(0, 4, 3, 1, 2).contiguous()          # [N, 2, 25, 3, 300] wsx: x(4,2,25,3,290)
+        x = x.view(N * M, V * C, T)                        # [2N, 75, 300] wsx: x(8,75,290)
 
-        x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25)
+        x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25)  # (2N,3,1,25)
 
         x_bn = self.data_bn(x)
         x_bn = x_bn.view(N, M, V, C, T)
         x_bn = x_bn.permute(0, 1, 3, 4, 2).contiguous()
-        x_bn = x_bn.view(N * M, C, T, V)
+        x_bn = x_bn.view(N * M, C, T, V)  # (2N,3,290,25)
 
         h0, _ = self.class_layer_0(x_bn, self.A * self.edge_importance[0], A_act, lamda_act)    # [N, 64, 300, 25]
         h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act)      # [N, 64, 300, 25]
@@ -74,10 +74,10 @@ def forward(self, x, x_target, x_last, A_act, lamda_act):
         h7, _ = self.class_layer_7(h6, self.A * self.edge_importance[7], A_act, lamda_act)      # [N, 256, 75, 25]
         h8, _ = self.class_layer_8(h7, self.A * self.edge_importance[8], A_act, lamda_act)      # [N, 256, 75, 25]
 
-        x_class = F.avg_pool2d(h8, h8.size()[2:])
-        x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1)
-        x_class = self.fcn(x_class)
-        x_class = x_class.view(x_class.size(0), -1)
+        x_class = F.avg_pool2d(h8, h8.size()[2:])           # (8,256,1,1)
+        x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1)  # (4,256,1,1)
+        x_class = self.fcn(x_class)                         # (4,60,1,1)  Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
+        x_class = x_class.view(x_class.size(0), -1)         # (4,60)
 
         r0, _ = self.recon_layer_0(h8, self.A*self.edge_importance_recon[0], A_act, lamda_act)   # [N, 128, 75, 25]
         r1, _ = self.recon_layer_1(r0, self.A*self.edge_importance_recon[1], A_act, lamda_act)   # [N, 128, 38, 25]
@@ -85,8 +85,8 @@ def forward(self, x, x_target, x_last, A_act, lamda_act):
         r3, _ = self.recon_layer_3(r2, self.A*self.edge_importance_recon[3], A_act, lamda_act)   # [N, 128, 10, 25]
         r4, _ = self.recon_layer_4(r3, self.A*self.edge_importance_recon[4], A_act, lamda_act)   # [N, 128, 5, 25]
         r5, _ = self.recon_layer_5(r4, self.A*self.edge_importance_recon[5], A_act, lamda_act)   # [N, 128, 1, 25]
-        r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act)   # [N, 64, 1, 25]
-        pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze()   # [N, 3, 25]
+        r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act)   # 
[N, 64, 1, 25] wsx:(8,30,1,25) + pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze() # [N, 3, 25] wsx:(8,30,25) pred = pred.contiguous().view(-1, 3, 10, 25) x_target = x_target.permute(0,4,1,2,3).contiguous().view(-1,3,10,25) diff --git a/net/model_poseformer.py b/net/model_poseformer.py new file mode 100644 index 0000000..e38ac89 --- /dev/null +++ b/net/model_poseformer.py @@ -0,0 +1,223 @@ +## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +import math +import logging +from functools import partial +from collections import OrderedDict +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+class PoseTransformer(nn.Module):
+    def __init__(self, num_frame=9, num_joints=25, in_chans=3, embed_dim_ratio=32, depth=4,
+                 num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None,
+                 num_class=60
+                 ):
+        """    ##########hybrid_backbone=None, representation_size=None,
+        Args:
+            num_frame (int, tuple): input frame number
+            num_joints (int, tuple): joints number
+            in_chans (int): number of input channels, 2D joints have 2 channels: (x,y)
+            embed_dim_ratio (int): embedding dimension ratio
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            norm_layer: (nn.Module): normalization layer
+            num_class (int): number of action classes
+        """
+        super().__init__()
+
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        embed_dim = embed_dim_ratio * num_joints   #### temporal embed_dim is num_joints * spatial embedding dim ratio
+        out_dim = num_joints * 3                   #### output dimension is num_joints * 3
+
+        ### spatial patch embedding
+        self.Spatial_patch_to_embedding = nn.Linear(3, 32)
+        self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio))
+
+        self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        self.Spatial_blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+            for i in range(depth)])
+
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+            for i in range(depth)])
+
+        self.Spatial_norm = norm_layer(embed_dim_ratio)
+        self.Temporal_norm = norm_layer(embed_dim)
+
+        ####### An easy way to implement weighted mean
+        self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1)
+
+        self.head = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, out_dim),
+        )
+
+        # wsx action_class_head
+        self.action_class_head = nn.Conv2d(290, num_class, kernel_size=1)
+
+        # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
+        self.data_bn = nn.BatchNorm1d(3 * 25)
+
+
+
+
+
+    def Spatial_forward_features(self, x):
+        b, _, f, p = x.shape  ##### b is batch size, f is number of frames, p is number of joints
+        x = rearrange(x, 'b c f p -> (b f) p c', )
+
+        x = self.Spatial_patch_to_embedding(x)
+        x += self.Spatial_pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.Spatial_blocks:
+            x = blk(x)
+
+        x = self.Spatial_norm(x)
+        x = rearrange(x, '(b f) w c -> b f (w c)', f=f)
+        return x
+
+    def 
forward_features(self, x): + b = x.shape[0] + x += self.Temporal_pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + + x = self.Temporal_norm(x) + ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame + # x = self.weighted_mean(x) #wsx don't change all frame to one + # x = x.view(b, 1, -1) + return x + + def forward(self, x, x_target): + ''' + # x input shape [170, 81, 17, 2] + x = x.permute(0, 3, 1, 2) #[170, 2, 81, 17] + b, _, _, p = x.shape #[170, 2, 81, 17] b:batch_size p:joint_num + ### now x is [batch_size, 2 channels, receptive frames, joint_num], following image data + ''' + + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + x = self.Spatial_forward_features(x) + x = self.forward_features(x) # (2n, 290,800) + + # action_class_head + BatchN, FrameN, FutureN = x.size() + x = x.view(BatchN, FrameN, FutureN, 1) + x_class = F.avg_pool2d(x, x.size()[2:]) + x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) + x_class = self.action_class_head(x_class) + x_class = x_class.view(x_class.size(0), -1) + + + #action_class = x.permute(0,2,1) #[170, 544, 1] + #action_class = self.action_class_head(action_class) + #action_class = torch.squeeze(action_class) + #x = self.head(x) + #x = x.view(b, 1, p, -1) + + x_target = x_target.permute(0, 4, 1, 2, 3).contiguous().view(-1, 3, 10, 25) + + return x_class, x_target[::2] # [170,1,17,3] + diff --git a/processor/gpu.py b/processor/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/processor/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. 
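+    (allocates a tiny tensor on each visible device so the process registers there)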
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/processor/io.py b/processor/io.py index fb9e4f8..b778d8c 100644 --- a/processor/io.py +++ b/processor/io.py @@ -8,9 +8,9 @@ import torch.nn as nn import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class class IO(): @@ -31,7 +31,7 @@ def load_arg(self, argv=None): if p.config is not None: # load config file with open(p.config, 'r') as f: - default_arg = yaml.load(f) + default_arg = yaml.safe_load(f) # update parser from config file key = vars(p).keys() @@ -48,13 +48,13 @@ def init_environment(self): self.save_dir = os.path.join(self.arg.work_dir, self.arg.max_hop_dir, self.arg.lamda_act_dir) - self.io = torchlight.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) + self.io = torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) self.io.save_arg(self.arg) # gpu if self.arg.use_gpu: - gpus = torchlight.visible_gpu(self.arg.device) - torchlight.occupy_gpu(gpus) + gpus = torchlight.gpu.visible_gpu(self.arg.device) + #torchlight.occupy_gpu(gpus) self.gpus = gpus self.dev = "cuda:0" else: diff --git a/processor/processor.py b/processor/processor.py index 03fc2cf..690d846 100644 --- a/processor/processor.py +++ b/processor/processor.py @@ -8,9 +8,9 @@ import torch.optim as optim import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class from .io import IO diff --git a/processor/recognition.py b/processor/recognition.py index d3905af..cf23165 100644 --- a/processor/recognition.py +++ b/processor/recognition.py @@ -13,9 +13,9 @@ import torch.optim as optim import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class from .processor import Processor @@ -86,7 +86,7 @@ def nll_gaussian(self, preds, target, variance, add_const=False): return neg_log_p.sum() / (target.size(0) * target.size(1)) def kl_categorical(self, preds, log_prior, num_node, eps=1e-16): - kl_div = preds*(torch.log(preds+eps)-log_prior) + kl_ddiv = preds*(torch.log(preds+eps)-log_prior) return kl_div.sum()/(num_node*preds.size(0)) @@ -112,6 +112,7 @@ def train(self, training_A=False): self.epoch_info.clear() for data, data_downsample, target_data, data_last, label in loader: + # data: (32,3,290,25,2) data_downsample:(32,3,50,25,2) target_data:(32,3,10,25,2) data_last:(32,3,1,25,2) label:(32) data = data.float().to(self.dev) data_downsample = data_downsample.float().to(self.dev) label = label.long().to(self.dev) @@ -158,6 +159,7 @@ def train(self, training_A=False): label = label.long().to(self.dev) A_batch, prob, outputs, _ = self.model2(data_downsample) + # wsx x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) loss_class = self.loss_class(x_class, label) loss_recon = self.loss_pred(pred, target) diff --git a/torchlight/__init__.py b/torchlight/__init__.py new file mode 100644 index 0000000..07e70f1 --- /dev/null +++ 
b/torchlight/__init__.py @@ -0,0 +1,8 @@ +from .io import IO +from .io import str2bool +from .io import str2dict +from .io import DictAction +from .io import import_class +from .gpu import visible_gpu +from .gpu import occupy_gpu +from .gpu import ngpu diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..2136593 Binary files /dev/null and b/torchlight/__pycache__/__init__.cpython-36.pyc differ diff --git a/torchlight/__pycache__/io.cpython-36.pyc b/torchlight/__pycache__/io.cpython-36.pyc new file mode 100644 index 0000000..9c8518c Binary files /dev/null and b/torchlight/__pycache__/io.cpython-36.pyc differ diff --git a/torchlight/build/lib/torchlight/__init__.py b/torchlight/build/lib/torchlight/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/torchlight/build/lib/torchlight/__init__.py @@ -0,0 +1 @@ + diff --git a/torchlight/build/lib/torchlight/gpu.py b/torchlight/build/lib/torchlight/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/torchlight/build/lib/torchlight/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/build/lib/torchlight/io.py b/torchlight/build/lib/torchlight/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/torchlight/build/lib/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if 
ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return 
getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/dist/torchlight-1.0-py3.6.egg b/torchlight/dist/torchlight-1.0-py3.6.egg new file mode 100644 index 0000000..46e5716 Binary files /dev/null and b/torchlight/dist/torchlight-1.0-py3.6.egg differ diff --git a/torchlight/gpu.py b/torchlight/gpu.py new file mode 100644 index 0000000..462faa9 --- /dev/null +++ b/torchlight/gpu.py @@ -0,0 +1,36 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + #os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2" + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/io.py b/torchlight/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + 
ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be 
found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/torchlight.egg-info/PKG-INFO b/torchlight/torchlight.egg-info/PKG-INFO new file mode 100644 index 0000000..53cafc2 --- /dev/null +++ b/torchlight/torchlight.egg-info/PKG-INFO @@ -0,0 +1,10 @@ +Metadata-Version: 1.0 +Name: torchlight +Version: 1.0 +Summary: A mini framework for pytorch +Home-page: UNKNOWN +Author: UNKNOWN +Author-email: UNKNOWN +License: UNKNOWN +Description: UNKNOWN +Platform: UNKNOWN diff --git a/torchlight/torchlight.egg-info/SOURCES.txt b/torchlight/torchlight.egg-info/SOURCES.txt new file mode 100644 index 0000000..1ee6009 --- /dev/null +++ b/torchlight/torchlight.egg-info/SOURCES.txt @@ -0,0 +1,8 @@ +setup.py +torchlight/__init__.py +torchlight/gpu.py +torchlight/io.py +torchlight.egg-info/PKG-INFO +torchlight.egg-info/SOURCES.txt +torchlight.egg-info/dependency_links.txt +torchlight.egg-info/top_level.txt \ No newline at end of file diff --git a/torchlight/torchlight.egg-info/dependency_links.txt b/torchlight/torchlight.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/torchlight/torchlight.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/torchlight/torchlight.egg-info/top_level.txt b/torchlight/torchlight.egg-info/top_level.txt new file mode 100644 index 0000000..c600430 --- /dev/null +++ b/torchlight/torchlight.egg-info/top_level.txt @@ -0,0 +1 @@ +torchlight
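
For convenience, the setup and run steps introduced by the README changes above can be chained as follows. This is a minimal sketch, not part of the patch: the `sed` one-liner is just one way to perform the import rewrite the README describes, and the file list it touches (`main.py`, `processor/*.py`) is an assumption based on the diffs in this patch.
```
# install the torchlight helper package
cd torchlight && python setup.py install && cd ..

# rewrite old-style imports (assumed file set; adjust to your checkout)
sed -i 's/from torchlight import /from torchlight.io import /' main.py processor/*.py

# generate the preprocessed NTU-RGB+D input data
cd data_gen && python ntu_gen_preprocess.py && cd ..

# Cross-Subject: pretrain AIM, train the main pipeline (single GPU only), then test
python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2
python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml --device 0 --batch_size 4
python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml
```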