From d3622f09ef091d05d77a0b90b50750a738ba2280 Mon Sep 17 00:00:00 2001 From: Bernd Illing <> Date: Fri, 22 Oct 2021 17:48:45 +0200 Subject: [PATCH] initial commit. added vision and video content --- LICENSE | 21 + README.md | 106 +++++ video/.gitignore | 6 + video/HowTo_tensorboard_files.txt | 13 + .../splits_classification/convert_paths.py | 25 ++ video/VGG_8.py | 264 ++++++++++++ video/augmentation.py | 381 +++++++++++++++++ video/commands.txt | 27 ++ video/dataset_3d.py | 235 +++++++++++ video/dataset_3d_lc.py | 280 +++++++++++++ video/env.yml | 69 ++++ video/env_setup.txt | 11 + video/finetune.py | 372 +++++++++++++++++ video/process_data/data/ucf101/classInd.txt | 101 +++++ video/process_data/readme.md | 51 +++ video/process_data/src/extract_frame.py | 111 +++++ video/process_data/src/write_csv.py | 115 ++++++ video/requirements.yml | 215 ++++++++++ video/run.py | 149 +++++++ video/test.py | 366 ++++++++++++++++ video/train.py | 87 ++++ video/utils.py | 44 ++ vision/.gitignore | 29 ++ vision/GreedyInfoMax/__init__.py | 0 vision/GreedyInfoMax/utils/__init__.py | 0 vision/GreedyInfoMax/utils/logger.py | 192 +++++++++ vision/GreedyInfoMax/utils/model_utils.py | 119 ++++++ vision/GreedyInfoMax/utils/utils.py | 69 ++++ vision/GreedyInfoMax/vision/__init__.py | 0 .../vision/arg_parser/__init__.py | 0 .../vision/arg_parser/arg_parser.py | 48 +++ .../vision/arg_parser/general_args.py | 73 ++++ .../vision/arg_parser/reload_args.py | 59 +++ .../vision/arg_parser/train_args.py | 213 ++++++++++ .../GreedyInfoMax/vision/compare_updates.py | 389 +++++++++++++++++ vision/GreedyInfoMax/vision/data/__init__.py | 0 .../vision/data/get_dataloader.py | 276 +++++++++++++ .../vision/downstream_classification.py | 222 ++++++++++ .../vision/get_acc_supervised.py | 68 +++ vision/GreedyInfoMax/vision/main_vision.py | 126 ++++++ .../vision/models/ClassificationModel.py | 34 ++ .../GreedyInfoMax/vision/models/FullModel.py | 194 +++++++++ .../vision/models/InfoNCE_Loss.py | 391 ++++++++++++++++++ .../vision/models/Supervised_Loss.py | 52 +++ .../vision/models/VGG_like_Encoder.py | 345 ++++++++++++++++ .../GreedyInfoMax/vision/models/__init__.py | 0 .../vision/models/load_vision_model.py | 49 +++ vision/GreedyInfoMax/vision/visualise.py | 294 +++++++++++++ vision/environment.yml | 101 +++++ vision/scripts/class_from_inter_layers.sh | 15 + vision/scripts/vision_traineval_CLAPP.sh | 11 + .../vision_traineval_CLAPP_s_sym_W_pred.sh | 11 + .../scripts/vision_traineval_HingeLossCPC.sh | 12 + vision/setup_dependencies.sh | 5 + 54 files changed, 6446 insertions(+) create mode 100644 LICENSE create mode 100755 README.md create mode 100644 video/.gitignore create mode 100755 video/HowTo_tensorboard_files.txt create mode 100644 video/UCF101/splits_classification/convert_paths.py create mode 100644 video/VGG_8.py create mode 100644 video/augmentation.py create mode 100644 video/commands.txt create mode 100644 video/dataset_3d.py create mode 100644 video/dataset_3d_lc.py create mode 100755 video/env.yml create mode 100755 video/env_setup.txt create mode 100644 video/finetune.py create mode 100644 video/process_data/data/ucf101/classInd.txt create mode 100644 video/process_data/readme.md create mode 100644 video/process_data/src/extract_frame.py create mode 100644 video/process_data/src/write_csv.py create mode 100755 video/requirements.yml create mode 100644 video/run.py create mode 100644 video/test.py create mode 100644 video/train.py create mode 100644 video/utils.py create mode 100644 vision/.gitignore create mode 100755 
vision/GreedyInfoMax/__init__.py create mode 100755 vision/GreedyInfoMax/utils/__init__.py create mode 100755 vision/GreedyInfoMax/utils/logger.py create mode 100755 vision/GreedyInfoMax/utils/model_utils.py create mode 100755 vision/GreedyInfoMax/utils/utils.py create mode 100755 vision/GreedyInfoMax/vision/__init__.py create mode 100755 vision/GreedyInfoMax/vision/arg_parser/__init__.py create mode 100755 vision/GreedyInfoMax/vision/arg_parser/arg_parser.py create mode 100755 vision/GreedyInfoMax/vision/arg_parser/general_args.py create mode 100755 vision/GreedyInfoMax/vision/arg_parser/reload_args.py create mode 100755 vision/GreedyInfoMax/vision/arg_parser/train_args.py create mode 100755 vision/GreedyInfoMax/vision/compare_updates.py create mode 100755 vision/GreedyInfoMax/vision/data/__init__.py create mode 100755 vision/GreedyInfoMax/vision/data/get_dataloader.py create mode 100755 vision/GreedyInfoMax/vision/downstream_classification.py create mode 100755 vision/GreedyInfoMax/vision/get_acc_supervised.py create mode 100755 vision/GreedyInfoMax/vision/main_vision.py create mode 100755 vision/GreedyInfoMax/vision/models/ClassificationModel.py create mode 100755 vision/GreedyInfoMax/vision/models/FullModel.py create mode 100755 vision/GreedyInfoMax/vision/models/InfoNCE_Loss.py create mode 100755 vision/GreedyInfoMax/vision/models/Supervised_Loss.py create mode 100755 vision/GreedyInfoMax/vision/models/VGG_like_Encoder.py create mode 100755 vision/GreedyInfoMax/vision/models/__init__.py create mode 100755 vision/GreedyInfoMax/vision/models/load_vision_model.py create mode 100755 vision/GreedyInfoMax/vision/visualise.py create mode 100644 vision/environment.yml create mode 100755 vision/scripts/class_from_inter_layers.sh create mode 100755 vision/scripts/vision_traineval_CLAPP.sh create mode 100755 vision/scripts/vision_traineval_CLAPP_s_sym_W_pred.sh create mode 100755 vision/scripts/vision_traineval_HingeLossCPC.sh create mode 100644 vision/setup_dependencies.sh diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d1ce6b4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Bernd Illing + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100755 index 0000000..9b9e54e --- /dev/null +++ b/README.md @@ -0,0 +1,106 @@ + + + +This is the code for the publication: + +B. Illing, J. Ventura, G. Bellec & W. 
Gerstner +[*Local plasticity rules can learn deep representations using self-supervised contrastive predictions*](https://arxiv.org/abs/2010.08262), accepted at NeurIPS 2021 + +Contact: +[bernd.illing@epfl.ch](mailto:bernd.illing@epfl.ch) + +# Structure of the code + +The code is divided into three independent sections, corresponding to the three applications we apply CLAPP to: + +* vision +* video +* audio + +Our implementation requires the following general dependencies: + +* python 3 +* conda + +Each section comes with its own dependencies handled by conda environments, as explained in the respective sections below. + +# Vision + +The implementation of the CLAPP vision experiments is based on Sindy Löwe's code of the [Greedy InfoMax model](https://github.com/loeweX/Greedy_InfoMax). + +## Setup + +To set up the conda environment, simply run + +```bash + bash ./vision/setup_dependencies.sh +``` + +To activate and deactivate the created conda environment, run + +```bash + conda activate infomax + conda deactivate +``` + +respectively. The environment name `infomax`, as well as the name of our python module `GreedyInfoMax`, are GIM code legacy. + +## Usage + +We included three sample scripts to run CLAPP, CLAPP-s (synchronous pos. and neg. updates; version with weight symmetry in $W^{pred}$) and Hinge Loss CPC (end-to-end version of CLAPP). To run, e.g., the Hinge Loss CPC simulations (model training + evaluation), run: + +```bash + bash ./vision/scripts/vision_traineval_HingeLossCPC.sh +``` + +The code includes many (experimental) versions of CLAPP as command-line options that are not used or mentioned in the paper. To view all command-line options of model training, run: + +```bash + cd vision + python -m GreedyInfoMax.vision.main_vision --help +``` + +Training in general uses auto-differentiation provided by `pytorch`. We checked that the obtained updates are equivalent to evaluating the CLAPP learning rules for $W$ and $W^{pred}$, Equations (6) - (8). The code used for this sanity check can be found in `./vision/GreedyInfoMax/vision/compare_updates.py` (a generic sketch of such a gradient check is shown further below). + + +# Video + +The implementation of the CLAPP video experiments was inspired by Tengda Han's code for [Dense Predictive Coding](https://github.com/TengdaHan/DPC). + +## Setup + +The setup of the conda environment is described in `./video/env_setup.txt`. To activate and deactivate the created conda environment `pdm`, run + +```bash + conda activate pdm + conda deactivate +``` + +respectively. + +## Usage + +The basic simulations described in the paper can be replicated using the commands listed in `./video/commands.txt`. + + +# Audio + +The implementation of the CLAPP audio experiments is based on Sindy Löwe's code of the [Greedy InfoMax model](https://github.com/loeweX/Greedy_InfoMax).
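As an aside to the sanity check mentioned in the Vision section above: the snippet below is a minimal, self-contained sketch (it is *not* the code in `compare_updates.py`; all tensor names and shapes are made up for illustration) of how an autograd gradient can be compared against a hand-derived update for a hinge score loss of the form $\max(0,\, 1 - y\, z^\top W c)$:

```python
# Hypothetical toy check: autograd gradient vs. hand-derived gradient of a
# hinge loss on the score u = z^T W c (c: context feature, z: "future" feature).
import torch

torch.manual_seed(0)
c = torch.randn(16)                            # context feature (made up)
z = torch.randn(16)                            # "future" feature (made up)
y = 1.0                                        # +1 for a positive pair, -1 for a negative pair
W = torch.randn(16, 16, requires_grad=True)    # prediction weights

u = z @ (W @ c)                                # score u = z^T W c
loss = torch.clamp(1.0 - y * u, min=0.0)       # hinge loss
loss.backward()

# Hand-derived gradient: dL/dW = -y * z c^T if the hinge is active, else 0
active = (1.0 - y * u).item() > 0
manual = -y * z.unsqueeze(1) * c.unsqueeze(0) if active else torch.zeros_like(W)
print(torch.allclose(W.grad, manual, atol=1e-6))   # expected: True
```

The actual check in `compare_updates.py` works on the full CLAPP model; the toy example above only illustrates the principle of comparing `W.grad` after `backward()` with an explicitly evaluated update.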
+ + +## Setup + +## Usage + +# Cite + +Please cite our paper if you use this code in your own work: + +``` +@article{illing2020local, + title={Local plasticity rules can learn deep representations using self-supervised contrastive predictions}, + author={Illing, Bernd and Ventura, Jean and Bellec, Guillaume and Gerstner, Wulfram}, + journal={arXiv preprint arXiv:2010.08262}, + year={2020} +} +``` diff --git a/video/.gitignore b/video/.gitignore new file mode 100644 index 0000000..c7e42bc --- /dev/null +++ b/video/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ +*.tar +*.rar +UCF101/videos/ +UCF101/frame/ +.DS_Store diff --git a/video/HowTo_tensorboard_files.txt b/video/HowTo_tensorboard_files.txt new file mode 100755 index 0000000..770518d --- /dev/null +++ b/video/HowTo_tensorboard_files.txt @@ -0,0 +1,13 @@ +$ ipython + +>> from tensorboard.backend.event_processing import event_accumulator +>> ea = event_accumulator.EventAccumulator('path+/events.out.tfevents.xx.xx') +(e.g. ea = event_accumulator.EventAccumulator('./temp_VGG_CLAPP_test/classification_all_layers/val/events.out.tfevents.1614370774.illing-clapp-video') ) +>> ea.Reload() +>> ea.Tags() + +-> ready to access + +e.g. +>> ea.Scalars('global/accuracy_4_top_1')[-10:] + diff --git a/video/UCF101/splits_classification/convert_paths.py b/video/UCF101/splits_classification/convert_paths.py new file mode 100644 index 0000000..30084b3 --- /dev/null +++ b/video/UCF101/splits_classification/convert_paths.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Dec 10 12:33:53 2020 + +@author: Jean +""" + +import os + +if __name__ == '__main__': + + files = os.listdir('./')[1:] + print(files) + + for file_name in files: + with open(file_name, 'r') as stream: + paths = stream.readlines() + print(paths[0]) + for path in paths: + path.replace('\\','/') + print(paths[0]) + with open(file_name, 'w') as stream: + stream.writelines(paths) + + \ No newline at end of file diff --git a/video/VGG_8.py b/video/VGG_8.py new file mode 100644 index 0000000..d0ff93c --- /dev/null +++ b/video/VGG_8.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn + + +# COMPUTE HINGE LOSS FOR CLAPP +class HingeLoss(nn.Module): + def __init__(self): + super(HingeLoss, self).__init__() + + def forward(self, input, positive): + # positive IS USED TO KNOW THE POSITIVES IN THE SCORES MATRIX (IN OUR CASE, IT'S THE DIAGONAL) + input[positive, positive] *= -1 + # APPLYING THE MAX(1+INPUT,0) + input = torch.clamp(1+input, min=0) + # SUMMATION OF THE 2 LOSS COMPONENTS + loss = 0.5*(torch.mean(input[positive, positive]) + (torch.sum(input)- torch.sum(input[positive, positive]))/((positive.size(0)-1)*(positive.size(0)))) + return loss + + +# COMPUTE THE TOP-K ACCURACIES +class top_k(nn.Module): + def __init__(self, k): + super(top_k, self).__init__() + # PUT K IN FORM OF LIST IF K IS SINGLE INT VALUE + k = [k] if isinstance(k, int) else k + self.k=k + + def forward(self, input): + accs = [] + #FIND TOP-K FOR EVERY K IN LIST + for k in self.k: + acc_k = [] + # FOR EACH TIMESTEP THAT WE WANT TO PREDICT + for time_pred in input: + # POSITIVE IS DIAGONAL, SPOT EVERY TIME THE DIAG ELEMENT IS IN THE TOP-K + acc_k.append(torch.mean(torch.tensor([(index == input_line).any().float() for (index, input_line) in enumerate(torch.topk(time_pred, k, dim=1)[1])]))) + # AVERAGE THE SCORE OVER ALL TIME STEPS WE TRY TO PREDICT + accs.append(torch.mean(torch.tensor(acc_k))) + return accs + + + +class Loss(nn.Module): + def __init__(self, mode, spatial_collapse, 
single_predictor, spatial_segm, predictor_bias, channels): + super(Loss, self).__init__() + # DEPENDING OF mode WE USE A DIFFERENT CATEGORICAL LOSS + if mode =='CPC' or mode == 'GIM': + self.loss = nn.CrossEntropyLoss() + else: + self.loss = HingeLoss() + + # SETS THE USE OF ONE PREDICTOR APPLIED RECURSIVELY FOR ALL TIMESTEPS OR MULTIPLE PREDICTORS + self.single_predictor = single_predictor + # DEFINES IF THE SPATIAL MAPS ARE GOING TO BE SEGMENTED TO CREATE SPATIAL NEGATIVES (DPC TECHNIQUE) + self.spatial_segm = spatial_segm + # DEFINES IF THE ACTIVATION MAPS ARE POOLED TO FORM Z (ORIGINAL CPC TECHNIQUE, NOT USED HERE EVEN FOR CPC) + self.spatial_collapse = spatial_collapse + # ADD A BIAS TO THE PREDICTION OPERATOR (NEVER USED) + self.predictor_bias = predictor_bias + + if self.single_predictor: + self.W = nn.Conv3d(channels, channels, kernel_size=(1,1,1), bias=self.predictor_bias) + else: + self.W = nn.ModuleList() + for i in range(3): # number of pred_steps + self.W.append(nn.Conv3d(channels, channels, kernel_size=(1,1,1), bias=self.predictor_bias)) + + + # FOR FUTURE WORK: CREATION OF THE MASK IDENTIFYING THE TYPE OF SAMPLE: POS, TEMP. NEG., BATCH. NEG. + self.mask_computed=False + + self._initialize_weights() + + + def _initialize_weights(self): + if self.single_predictor: + nn.init.kaiming_normal_(self.W.weight, mode='fan_out', nonlinearity='relu') + if self.W.bias is not None: + nn.init.constant_(self.W.bias, 0) + else: + for m in self.W: + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + def forward(self, input): + + if self.spatial_collapse: + # IF SPATIAL COLLAPSE, PASS THROUGH POOLING + collapse = nn.AvgPool3d(kernel_size=(1, input.size(-1), input.size(-1))) + input = collapse(input) + + # FIRST T-3 FRAMES TO USE FOR PREDICTION + state = input[:,:,:-3] + # TARGET FOR PREDCTING 1 TIME STEP AHEAD + first_targ = input[:,:,1:-2].permute(1,3,4,0,2) + # TARGET FOR PREDCTING 2 TIME STEPS AHEAD + second_targ = input[:,:,2:-1].permute(1,3,4,0,2) + # TARGET FOR PREDCTING 3 TIME STEPS AHEAD + third_targ = input[:,:,3:].permute(1,3,4,0,2) + + #PERMUTATION: (B,C,(T-3),X,Y) -> (C,X,Y,B,(T-3)) + if self.single_predictor: + # APPLY PREDICTOR RECURSIVELY + first_pred = self.W(state).permute(1,3,4,0,2) + second_pred = self.W(self.W(state)).permute(1,3,4,0,2) + third_pred = self.W(self.W(self.W(state))).permute(1,3,4,0,2) + + + else: + first_pred = self.W[0](state).permute(1,3,4,0,2) + second_pred = self.W[1](state).permute(1,3,4,0,2) + third_pred = self.W[2](state).permute(1,3,4,0,2) + + # (C,X,Y,B,(T-3)) -> (CxXxY, Bx(T-3)) (FLATTENING ACTIVATION MAPS) or (C,XxYxBx(T-3)) (DPC) + if self.spatial_segm: + index = 1 + else: + index = 3 + + # PERFORM THE FLATTENING DEPENDING ON THE USE OF SPATIAL NEGATIVES + first_targ = torch.flatten(torch.flatten(first_targ, start_dim=index), end_dim=index-1) + second_targ = torch.flatten(torch.flatten(second_targ, start_dim=index), end_dim=index-1) + third_targ = torch.flatten(torch.flatten(third_targ, start_dim=index), end_dim=index-1) + + first_pred = torch.flatten(torch.flatten(first_pred, start_dim=index), end_dim=index-1) + second_pred = torch.flatten(torch.flatten(second_pred, start_dim=index), end_dim=index-1) + third_pred = torch.flatten(torch.flatten(third_pred, start_dim=index), end_dim=index-1) + + # COMPUTING THE SCORE BY MATRIX MULTIPLICATION + first_score = torch.matmul(first_targ.transpose(0,1),first_pred).transpose(0,1) + second_score = 
torch.matmul(second_targ.transpose(0,1),second_pred).transpose(0,1) + third_score = torch.matmul(third_targ.transpose(0,1),third_pred).transpose(0,1) + + # POSITIVE SAMPLES ARE THE DIAGONAL + positive = torch.arange(0,first_score.size(0)).cuda() + + return (self.loss(first_score, positive)+self.loss(second_score, positive)+self.loss(third_score, positive))/3, [first_score.detach(), second_score.detach(), third_score.detach()] + + + +# CLASS FOR THE CONV+RELU+BN MODULE, NAME IS MISLEADING BUT BN IS LAST IN MODULE +class ConvBNReLU(nn.Module): + def __init__(self,in_ch, out_ch, k_size, stride, padding): + super(ConvBNReLU, self).__init__() + + self.conv = nn.Conv3d(in_ch, out_ch, k_size, stride=stride, padding=padding) + self.norm = nn.BatchNorm3d(out_ch) + self.relu = nn.ReLU(inplace=True) + + self.layers = nn.Sequential(self.conv, self.relu, self.norm) + + self._initialize_weights() + + def forward(self, x): + return self.layers(x) + + def _initialize_weights(self): + for m in self.layers: + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +# CLASS FOR THE VGG-8 NETWORK +class VGG_8(nn.Module): + def __init__(self, temp ,mode, spatial_collapse, single_predictor, spatial_segm, predictor_bias, no_ss_loss=False): + super(VGG_8, self).__init__() + + self.mode = mode + self.no_loss = no_ss_loss + + in_channels = 3 + + #IF temp, SET THE LAST 2 CONVOLUTIONS WITH TEMPORAL KERNELS AND CONSEQUENT STRIDE + if temp: + time_kernel=3 + else: + time_kernel=1 + + self.conv1 = ConvBNReLU(in_channels, 96,(1, 7, 7), stride=(1,2,2), padding=0) + self.maxpool1 = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)) + self.block1 = nn.Sequential(self.conv1, self.maxpool1) + + self.conv2 = ConvBNReLU(96, 256, (1,5,5), stride=(1,2,2), padding=(0,1,1)) + self.maxpool2 = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)) + self.block2 = nn.Sequential(self.conv2, self.maxpool2) + + + self.conv3 = ConvBNReLU(256, 512, (1,3,3),stride=(1,1,1),padding=(0,1,1)) + self.block3 = nn.Sequential(self.conv3) + self.conv4 = ConvBNReLU(512, 512, (time_kernel,3,3),stride=(time_kernel,1,1), padding=(0,1,1)) + self.block4 = nn.Sequential(self.conv4) + self.conv5 = ConvBNReLU(512, 512, (time_kernel,3,3),stride=(time_kernel,1,1), padding=(0,1,1)) + self.maxpool3 = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)) + self.block5 = nn.Sequential(self.conv5, self.maxpool3) + + # LIST OF NETWORK BLOCKS, IF PER-LAYER, LOSS APPLIED TO EACH BLOCK + self.blocks = nn.ModuleList([self.block1, self.block2 , self.block3, self.block4, self.block5]) + + self.losses = nn.ModuleList() + self.accs = nn.ModuleList() + + if self.mode =='GIM' or self.mode =='CLAPP': + for block in self.blocks: + self.losses.append(Loss(mode, spatial_collapse, single_predictor, spatial_segm, predictor_bias, block[0].conv.weight.size(0))) + self.accs.append(top_k([1,3,5])) + else: + self.losses.append(Loss(mode, spatial_collapse, single_predictor, spatial_segm, predictor_bias, self.blocks[-1][0].conv.weight.size(0))) + self.accs.append(top_k([1,3,5])) + + def get_nb_losses(self): + return len(self.losses) + + + def forward(self, input): + self.losses_val = [] + self.accs_val = [] + + res = input + for i, block in enumerate(self.blocks): + if self.mode =='CLAPP' or self.mode=='GIM': + # IF CLAPP OR GIM, DISCONNECT INPUT + res = block(res.detach()) + if self.no_loss: + 
# OPTION TO PREVENT SELF-SUPERVISION LOSS IF CLASSIFICATION + loss = None + accs= None + else: + # COMPUTE LOSS AND SCORES + loss, scores = self.losses[i](res) + accs = self.accs[i](scores) + self.losses_val.append(loss) + self.accs_val.append(accs) + + else: + # IF NOT CLAPP OR GIM, FORWARD NORMALLY + res = block(res) + + + if self.mode =='CPC' or self.mode=='HingeCPC': + # IF HINGECPC OR CPC, COMPUTE LOSS AND SCORES WITH LAST OUTPUT + if self.no_loss: + loss=None + accs=None + else: + # INEX [0] BECAUSE LOSSES AND ACCS ARE ALWAYS A LIST + loss, scores = self.losses[0](res) + accs = self.accs[0](scores) + self.losses_val.append(loss) + self.accs_val.append(accs) + + return self.losses_val, self.accs_val, res + + + + \ No newline at end of file diff --git a/video/augmentation.py b/video/augmentation.py new file mode 100644 index 0000000..6d6f49e --- /dev/null +++ b/video/augmentation.py @@ -0,0 +1,381 @@ +import random +import numbers +import math +import collections +import numpy as np +from PIL import ImageOps, Image +from joblib import Parallel, delayed + +import torchvision +from torchvision import transforms +import torchvision.transforms.functional as F + +class Padding: + def __init__(self, pad): + self.pad = pad + + def __call__(self, img): + return ImageOps.expand(img, border=self.pad, fill=0) + +class Scale: + def __init__(self, size, interpolation=Image.NEAREST): + assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) + self.size = size + self.interpolation = interpolation + + def __call__(self, imgmap): + # assert len(imgmap) > 1 # list of images + img1 = imgmap[0] + if isinstance(self.size, int): + w, h = img1.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + return imgmap + if w < h: + ow = self.size + oh = int(self.size * h / w) + return [i.resize((ow, oh), self.interpolation) for i in imgmap] + else: + oh = self.size + ow = int(self.size * w / h) + return [i.resize((ow, oh), self.interpolation) for i in imgmap] + else: + return [i.resize(self.size, self.interpolation) for i in imgmap] + + +class CenterCrop: + def __init__(self, size, consistent=True): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, imgmap): + img1 = imgmap[0] + w, h = img1.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] + + +class RandomCropWithProb: + def __init__(self, size, p=0.8, consistent=True): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.consistent = consistent + self.threshold = p + + def __call__(self, imgmap): + img1 = imgmap[0] + w, h = img1.size + if self.size is not None: + th, tw = self.size + if w == tw and h == th: + return imgmap + if self.consistent: + if random.random() < self.threshold: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + else: + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] + else: + result = [] + for i in imgmap: + if random.random() < self.threshold: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + else: + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + result.append(i.crop((x1, y1, x1 + tw, y1 + th))) + return result + else: + return imgmap + +class RandomCrop: + def __init__(self, size, consistent=True): + if isinstance(size, 
numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.consistent = consistent + + def __call__(self, imgmap, flowmap=None): + img1 = imgmap[0] + w, h = img1.size + if self.size is not None: + th, tw = self.size + if w == tw and h == th: + return imgmap + if not flowmap: + if self.consistent: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] + else: + result = [] + for i in imgmap: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + result.append(i.crop((x1, y1, x1 + tw, y1 + th))) + return result + elif flowmap is not None: + assert (not self.consistent) + result = [] + for idx, i in enumerate(imgmap): + proposal = [] + for j in range(3): # number of proposal: use the one with largest optical flow + x = random.randint(0, w - tw) + y = random.randint(0, h - th) + proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))]) + [x1, y1, _] = max(proposal, key=lambda x: x[-1]) + result.append(i.crop((x1, y1, x1 + tw, y1 + th))) + return result + else: + raise ValueError('wrong case') + else: + return imgmap + + +class RandomSizedCrop: + def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): + self.size = size + self.interpolation = interpolation + self.consistent = consistent + self.threshold = p + + def __call__(self, imgmap): + img1 = imgmap[0] + if random.random() < self.threshold: # do RandomSizedCrop + for attempt in range(10): + area = img1.size[0] * img1.size[1] + target_area = random.uniform(0.5, 1) * area + aspect_ratio = random.uniform(3. / 4, 4. / 3) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if self.consistent: + if random.random() < 0.5: + w, h = h, w + if w <= img1.size[0] and h <= img1.size[1]: + x1 = random.randint(0, img1.size[0] - w) + y1 = random.randint(0, img1.size[1] - h) + + imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] + for i in imgmap: assert(i.size == (w, h)) + + return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] + else: + result = [] + for i in imgmap: + if random.random() < 0.5: + w, h = h, w + if w <= img1.size[0] and h <= img1.size[1]: + x1 = random.randint(0, img1.size[0] - w) + y1 = random.randint(0, img1.size[1] - h) + result.append(i.crop((x1, y1, x1 + w, y1 + h))) + assert(result[-1].size == (w, h)) + else: + result.append(i) + + assert len(result) == len(imgmap) + return [i.resize((self.size, self.size), self.interpolation) for i in result] + + # Fallback + scale = Scale(self.size, interpolation=self.interpolation) + crop = CenterCrop(self.size) + return crop(scale(imgmap)) + else: # don't do RandomSizedCrop, do CenterCrop + crop = CenterCrop(self.size) + return crop(imgmap) + + +class RandomHorizontalFlip: + def __init__(self, consistent=True, command=None): + self.consistent = consistent + if command == 'left': + self.threshold = 0 + elif command == 'right': + self.threshold = 1 + else: + self.threshold = 0.5 + def __call__(self, imgmap): + if self.consistent: + if random.random() < self.threshold: + return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] + else: + return imgmap + else: + result = [] + for i in imgmap: + if random.random() < self.threshold: + result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) + else: + result.append(i) + assert len(result) == len(imgmap) + return result + + +class RandomGray: + '''Actually it is a channel splitting, not strictly grayscale images''' 
+ def __init__(self, consistent=True, p=0.5): + self.consistent = consistent + self.p = p # probability to apply grayscale + def __call__(self, imgmap): + if self.consistent: + if random.random() < self.p: + return [self.grayscale(i) for i in imgmap] + else: + return imgmap + else: + result = [] + for i in imgmap: + if random.random() < self.p: + result.append(self.grayscale(i)) + else: + result.append(i) + assert len(result) == len(imgmap) + return result + + def grayscale(self, img): + channel = np.random.choice(3) + np_img = np.array(img)[:,:,channel] + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, 'RGB') + return img + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation of an image. --modified from pytorch source code + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + self.consistent = consistent + self.threshold = p + + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + Arguments are same as that of __init__. + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. 
+ """ + transforms = [] + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) + + random.shuffle(transforms) + transform = torchvision.transforms.Compose(transforms) + + return transform + + def __call__(self, imgmap): + if random.random() < self.threshold: # do ColorJitter + if self.consistent: + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + return [transform(i) for i in imgmap] + else: + result = [] + for img in imgmap: + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + result.append(transform(img)) + return result + else: # don't do ColorJitter, do nothing + return imgmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) + return format_string + + +class RandomRotation: + def __init__(self, consistent=True, degree=15, p=1.0): + self.consistent = consistent + self.degree = degree + self.threshold = p + def __call__(self, imgmap): + if random.random() < self.threshold: # do RandomRotation + if self.consistent: + deg = np.random.randint(-self.degree, self.degree, 1)[0] + return [i.rotate(deg, expand=True) for i in imgmap] + else: + return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap] + else: # don't do RandomRotation, do nothing + return imgmap + +class ToTensor: + def __call__(self, imgmap): + totensor = transforms.ToTensor() + return [totensor(i) for i in imgmap] + +class Normalize: + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + def __call__(self, imgmap): + normalize = transforms.Normalize(mean=self.mean, std=self.std) + return [normalize(i) for i in imgmap] + + diff --git a/video/commands.txt b/video/commands.txt new file mode 100644 index 0000000..934f98c --- /dev/null +++ b/video/commands.txt @@ -0,0 +1,27 @@ + +Commands self-supervised training: + +python run.py --seq_len 8 --mode HingeCPC --batch_size 16 --epochs 300 --gpu 0 --name std_VGG_HingeCPC +python run.py --seq_len 8 --mode CLAPP --batch_size 16 --epochs 300 --gpu 1 --name std_VGG_CLAPP +python run.py --seq_len 8 --mode GIM --batch_size 16 --epochs 300 --gpu 1 --name std_VGG_GIM +python run.py --seq_len 8 --mode CPC --batch_size 16 --epochs 300 --gpu 0 --name std_VGG_CPC + +python run.py --seq_len 54 --mode HingeCPC --temp_VGG --batch_size 8 --epochs 300 --gpu 0 --name temp_VGG_HingeCPC +python run.py --seq_len 54 --mode CLAPP --temp_VGG --batch_size 8 --epochs 300 --gpu 0 --name temp_VGG_CLAPP +python run.py --seq_len 54 --mode GIM --temp_VGG --batch_size 8 --epochs 300 --gpu 1 
--name temp_VGG_GIM +python run.py --seq_len 54 --mode CPC --temp_VGG --batch_size 8 --epochs 300 --gpu 1 --name temp_VGG_CPC + + +Commands classification with frozen encoder: + +python test.py --seq_len 8 --mode HingeCPC --batch_size 16 --epochs 200 --gpu 0 --lr 1e-4 --name std_VGG_HingeCPC --monitor_all_layers +python test.py --seq_len 8 --mode CLAPP --batch_size 16 --epochs 200 --gpu 1 --lr 1e-4 --name std_VGG_CLAPP --monitor_all_layers +python test.py --seq_len 8 --mode GIM --batch_size 16 --epochs 200 --gpu 1 --lr 1e-4 --name std_VGG_GIM --monitor_all_layers +python test.py --seq_len 8 --mode CPC --batch_size 16 --epochs 200 --gpu 0 --lr 1e-4 --name std_VGG_CPC --monitor_all_layers +python test.py --seq_len 8 --mode CPC --batch_size 16 --epochs 200 --gpu 0 --lr 1e-4 --name std_VGG_random --monitor_all_layers + +python test.py --seq_len 72 --mode HingeCPC --temp_VGG --batch_size 8 --epochs 200 --gpu 1 --lr 1e-4 --name temp_VGG_HingeCPC --monitor_all_layers +python test.py --seq_len 72 --mode CLAPP --temp_VGG --batch_size 8 --epochs 200 --gpu 1 --lr 1e-4 --name temp_VGG_CLAPP --monitor_all_layers +python test.py --seq_len 72 --mode GIM --temp_VGG --batch_size 8 --epochs 200 --gpu 1 --lr 1e-4 --name temp_VGG_GIM --monitor_all_layers +python test.py --seq_len 72 --mode CPC --temp_VGG --batch_size 8 --epochs 200 --gpu 0 --lr 1e-4 --name temp_VGG_CPC --monitor_all_layers +python test.py --seq_len 72 --mode CPC --temp_VGG --batch_size 8 --epochs 200 --gpu 0 --lr 1e-4 --name temp_VGG_random --monitor_all_layers \ No newline at end of file diff --git a/video/dataset_3d.py b/video/dataset_3d.py new file mode 100644 index 0000000..4b8cf0c --- /dev/null +++ b/video/dataset_3d.py @@ -0,0 +1,235 @@ +import torch +from torch.utils import data +from torchvision import transforms +import os +import sys +import time +import pickle +import glob +import csv +import pandas as pd +import numpy as np +import cv2 +sys.path.append('../utils') +from augmentation import * +from tqdm import tqdm +from joblib import Parallel, delayed + +def pil_loader(path): + with open(path, 'rb') as f: + with Image.open(f) as img: + return img.convert('RGB') + + +class Kinetics400_full_3d(data.Dataset): + def __init__(self, + mode='train', + transform=None, + seq_len=10, + num_seq=5, + downsample=3, + epsilon=5, + unit_test=False, + big=False, + return_label=False): + self.mode = mode + self.transform = transform + self.seq_len = seq_len + self.num_seq = num_seq + self.downsample = downsample + self.epsilon = epsilon + self.unit_test = unit_test + self.return_label = return_label + + if big: print('Using Kinetics400 full data (256x256)') + else: print('Using Kinetics400 full data (150x150)') + + # get action list + self.action_dict_encode = {} + self.action_dict_decode = {} + action_file = os.path.join('process_data/data/kinetics400', 'classInd.txt') + action_df = pd.read_csv(action_file, sep=',', header=None) + for _, row in action_df.iterrows(): + act_id, act_name = row + act_id = int(act_id) - 1 # let id start from 0 + self.action_dict_decode[act_id] = act_name + self.action_dict_encode[act_name] = act_id + + # splits + if big: + if mode == 'train': + split = 'process_data/data/kinetics400_256/train_split.csv' + video_info = pd.read_csv(split, header=None) + elif (mode == 'val') or (mode == 'test'): + split = 'process_data/data/kinetics400_256/val_split.csv' + video_info = pd.read_csv(split, header=None) + else: raise ValueError('wrong mode') + else: # small + if mode == 'train': + split = 
'process_data/data/kinetics400/train_split.csv' + video_info = pd.read_csv(split, header=None) + elif (mode == 'val') or (mode == 'test'): + split = 'process_data/data/kinetics400/val_split.csv' + video_info = pd.read_csv(split, header=None) + else: raise ValueError('wrong mode') + + drop_idx = [] + print('filter out too short videos ...') + for idx, row in tqdm(video_info.iterrows(), total=len(video_info)): + vpath, vlen = row + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: + drop_idx.append(idx) + self.video_info = video_info.drop(drop_idx, axis=0) + + if mode == 'val': self.video_info = self.video_info.sample(frac=0.3, random_state=666) + if self.unit_test: self.video_info = self.video_info.sample(32, random_state=666) + # shuffle not necessary because use RandomSampler + + def idx_sampler(self, vlen, vpath): + '''sample index from a video''' + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] + n = 1 + start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) + seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx + seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample + return [seq_idx_block, vpath] + + def __getitem__(self, index): + vpath, vlen = self.video_info.iloc[index] + items = self.idx_sampler(vlen, vpath) + if items is None: print(vpath) + + idx_block, vpath = items + assert idx_block.shape == (self.num_seq, self.seq_len) + idx_block = idx_block.reshape(self.num_seq*self.seq_len) + + seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] + t_seq = self.transform(seq) # apply same transform + + (C, H, W) = t_seq[0].size() + t_seq = torch.stack(t_seq, 0) + t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) + + if self.return_label: + try: + vname = vpath.split('/')[-3] + vid = self.encode_action(vname) + except: + vname = vpath.split('/')[-2] + vid = self.encode_action(vname) + + label = torch.LongTensor([vid]) + return t_seq, label + + return t_seq + + def __len__(self): + return len(self.video_info) + + def encode_action(self, action_name): + '''give action name, return category''' + return self.action_dict_encode[action_name] + + def decode_action(self, action_code): + '''give action code, return action name''' + return self.action_dict_decode[action_code] + + +class UCF101_3d(data.Dataset): + def __init__(self, + mode='train', + transform=None, + seq_len=10, + num_seq = 5, + downsample=3, + epsilon=5, + which_split=1, + return_label=False): + self.mode = mode + self.transform = transform + self.seq_len = seq_len + self.num_seq = num_seq + self.downsample = downsample + self.epsilon = epsilon + self.which_split = which_split + self.return_label = return_label + + # splits + if mode == 'train': + split = 'process_data/data/ucf101/train_split%02d.csv' % self.which_split + video_info = pd.read_csv(split, header=None) + elif (mode == 'val') or (mode == 'test'): # use val for test + split = 'process_data/data/ucf101/test_split%02d.csv' % self.which_split + video_info = pd.read_csv(split, header=None) + else: raise ValueError('wrong mode') + + # get action list + self.action_dict_encode = {} + self.action_dict_decode = {} + action_file = os.path.join('process_data/data/ucf101', 'classInd.txt') + action_df = pd.read_csv(action_file, sep=' ', header=None) + for _, row in action_df.iterrows(): + act_id, act_name = row + self.action_dict_decode[act_id] = act_name + self.action_dict_encode[act_name] = act_id 
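+            # note: unlike Kinetics400_full_3d above (and the classes in dataset_3d_lc.py), act_id here is kept 1-indexed as read from classInd.txt; it is only used when return_label=True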
+ + # filter out too short videos: + drop_idx = [] + for idx, row in video_info.iterrows(): + vpath, vlen = row + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: + drop_idx.append(idx) + self.video_info = video_info.drop(drop_idx, axis=0) + + if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) + # shuffle not required due to external sampler + + def idx_sampler(self, vlen, vpath): + '''sample index from a video''' + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] + n = 1 + start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) + seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx + seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample + return [seq_idx_block, vpath] + + + def __getitem__(self, index): + vpath, vlen = self.video_info.iloc[index] + items = self.idx_sampler(vlen, vpath) + if items is None: print(vpath) + + idx_block, vpath = items + assert idx_block.shape == (self.num_seq, self.seq_len) + idx_block = idx_block.reshape(self.num_seq*self.seq_len) + + seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] + t_seq = self.transform(seq) # apply same transform + + (C, H, W) = t_seq[0].size() + t_seq = torch.stack(t_seq, 0) + t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) + + if self.return_label: + try: + vname = vpath.split('/')[-3] + vid = self.encode_action(vname) + except: + vname = vpath.split('/')[-2] + vid = self.encode_action(vname) + label = torch.LongTensor([vid]) + return t_seq, label + + return t_seq + + def __len__(self): + return len(self.video_info) + + def encode_action(self, action_name): + '''give action name, return action code''' + return self.action_dict_encode[action_name] + + def decode_action(self, action_code): + '''give action code, return action name''' + return self.action_dict_decode[action_code] + diff --git a/video/dataset_3d_lc.py b/video/dataset_3d_lc.py new file mode 100644 index 0000000..90dc342 --- /dev/null +++ b/video/dataset_3d_lc.py @@ -0,0 +1,280 @@ +import torch +from torch.utils import data +from torchvision import transforms +import os +import sys +import time +import pickle +import csv +import glob +import pandas as pd +import numpy as np +import cv2 +sys.path.append('../utils') +from augmentation import * +from tqdm import tqdm +from joblib import Parallel, delayed + +def pil_loader(path): + with open(path, 'rb') as f: + with Image.open(f) as img: + return img.convert('RGB') + +class UCF101_3d(data.Dataset): + def __init__(self, + mode='train', + transform=None, + seq_len=10, + num_seq =1, + downsample=3, + epsilon=5, + which_split=1): + self.mode = mode + self.transform = transform + self.seq_len = seq_len + self.num_seq = num_seq + self.downsample = downsample + self.epsilon = epsilon + self.which_split = which_split + + # splits + if mode == 'train': + split = 'process_data/data/ucf101/train_split%02d.csv' % self.which_split + video_info = pd.read_csv(split, header=None) + elif (mode == 'val') or (mode == 'test'): + split = 'process_data/data/ucf101/test_split%02d.csv' % self.which_split # use test for val, temporary + video_info = pd.read_csv(split, header=None) + else: raise ValueError('wrong mode') + + # get action list + self.action_dict_encode = {} + self.action_dict_decode = {} + + action_file = os.path.join('process_data/data/ucf101', 'classInd.txt') + action_df = pd.read_csv(action_file, sep=' ', header=None) + for _, row 
in action_df.iterrows(): + act_id, act_name = row + act_id = int(act_id) - 1 # let id start from 0 + self.action_dict_decode[act_id] = act_name + self.action_dict_encode[act_name] = act_id + + # filter out too short videos: + drop_idx = [] + for idx, row in video_info.iterrows(): + vpath, vlen = row + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: + drop_idx.append(idx) + self.video_info = video_info.drop(drop_idx, axis=0) + + if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) + # shuffle not required + + def idx_sampler(self, vlen, vpath): + '''sample index from a video''' + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] + n = 1 + if self.mode == 'test': + seq_idx_block = np.arange(0, vlen, self.downsample) # all possible frames with downsampling + return [seq_idx_block, vpath] + start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) + seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx + seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample + return [seq_idx_block, vpath] + + + def __getitem__(self, index): + vpath, vlen = self.video_info.iloc[index] + items = self.idx_sampler(vlen, vpath) + if items is None: print(vpath) + + idx_block, vpath = items + if self.mode != 'test': + assert idx_block.shape == (self.num_seq, self.seq_len) + idx_block = idx_block.reshape(self.num_seq*self.seq_len) + + seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] + t_seq = self.transform(seq) # apply same transform + + num_crop = None + try: + (C, H, W) = t_seq[0].size() + t_seq = torch.stack(t_seq, 0) + except: + (C, H, W) = t_seq[0][0].size() + tmp = [torch.stack(i, 0) for i in t_seq] + assert len(tmp) == 5 + num_crop = 5 + t_seq = torch.stack(tmp, 1) + + if self.mode == 'test': + # return all available clips, but cut into length = num_seq + SL = t_seq.size(0) + clips = []; i = 0 + while i+self.seq_len <= SL: + clips.append(t_seq[i:i+self.seq_len, :]) + # i += self.seq_len//2 + i += self.seq_len + if num_crop: + # half overlap: + clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] + NC = len(clips) + t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) + else: + # half overlap: + clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] + t_seq = torch.stack(clips, 0) + else: + t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) + + try: + vname = vpath.split('/')[-3] + vid = self.encode_action(vname) + except: + vname = vpath.split('/')[-2] + vid = self.encode_action(vname) + + label = torch.LongTensor([vid]) + + return t_seq, label + + def __len__(self): + return len(self.video_info) + + def encode_action(self, action_name): + '''give action name, return category''' + return self.action_dict_encode[action_name] + + def decode_action(self, action_code): + '''give action code, return action name''' + return self.action_dict_decode[action_code] + + +class HMDB51_3d(data.Dataset): + def __init__(self, + mode='train', + transform=None, + seq_len=10, + num_seq=1, + downsample=1, + epsilon=5, + which_split=1): + self.mode = mode + self.transform = transform + self.seq_len = seq_len + self.num_seq = num_seq + self.downsample = downsample + self.epsilon = epsilon + self.which_split = which_split + + # splits + if mode == 'train': + 
split = 'process_data/data/hmdb51/train_split%02d.csv' % self.which_split + video_info = pd.read_csv(split, header=None) + elif (mode == 'val') or (mode == 'test'): + split = 'process_data/data/hmdb51/test_split%02d.csv' % self.which_split # use test for val, temporary + video_info = pd.read_csv(split, header=None) + else: raise ValueError('wrong mode') + + # get action list + self.action_dict_encode = {} + self.action_dict_decode = {} + + action_file = os.path.join('process_data/data/hmdb51', 'classInd.txt') + action_df = pd.read_csv(action_file, sep=' ', header=None) + for _, row in action_df.iterrows(): + act_id, act_name = row + act_id = int(act_id) - 1 # let id start from 0 + self.action_dict_decode[act_id] = act_name + self.action_dict_encode[act_name] = act_id + + # filter out too short videos: + drop_idx = [] + for idx, row in video_info.iterrows(): + vpath, vlen = row + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: + drop_idx.append(idx) + self.video_info = video_info.drop(drop_idx, axis=0) + + if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) + # shuffle not required + + def idx_sampler(self, vlen, vpath): + '''sample index from a video''' + if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] + n = 1 + if self.mode == 'test': + seq_idx_block = np.arange(0, vlen, self.downsample) # all possible frames with downsampling + return [seq_idx_block, vpath] + start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) + seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx + seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample + return [seq_idx_block, vpath] + + + def __getitem__(self, index): + vpath, vlen = self.video_info.iloc[index] + items = self.idx_sampler(vlen, vpath) + if items is None: print(vpath) + + idx_block, vpath = items + if self.mode != 'test': + assert idx_block.shape == (self.num_seq, self.seq_len) + idx_block = idx_block.reshape(self.num_seq*self.seq_len) + + seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] + t_seq = self.transform(seq) # apply same transform + + num_crop = None + try: + (C, H, W) = t_seq[0].size() + t_seq = torch.stack(t_seq, 0) + except: + (C, H, W) = t_seq[0][0].size() + tmp = [torch.stack(i, 0) for i in t_seq] + assert len(tmp) == 5 + num_crop = 5 + t_seq = torch.stack(tmp, 1) + # print(t_seq.size()) + # import ipdb; ipdb.set_trace() + if self.mode == 'test': + # return all available clips, but cut into length = num_seq + SL = t_seq.size(0) + clips = []; i = 0 + while i+self.seq_len <= SL: + clips.append(t_seq[i:i+self.seq_len, :]) + # i += self.seq_len//2 + i += self.seq_len + if num_crop: + # half overlap: + clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] + NC = len(clips) + t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) + else: + # half overlap: + clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,3*self.num_seq//4)] + t_seq = torch.stack(clips, 0) + else: + t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) + + try: + vname = vpath.split('/')[-3] + vid = self.encode_action(vname) + except: + vname = vpath.split('/')[-2] + vid = self.encode_action(vname) + + label = torch.LongTensor([vid]) + + return t_seq, label + + def __len__(self): + return len(self.video_info) + + 
def encode_action(self, action_name): + '''give action name, return category''' + return self.action_dict_encode[action_name] + + def decode_action(self, action_code): + '''give action code, return action name''' + return self.action_dict_decode[action_code] + diff --git a/video/env.yml b/video/env.yml new file mode 100755 index 0000000..6b85a20 --- /dev/null +++ b/video/env.yml @@ -0,0 +1,69 @@ +name: pdm +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - blas=1.0=mkl + - ca-certificates=2020.12.5=ha878542_0 + - certifi=2020.12.5=py38h578d9bd_1 + - cudatoolkit=11.0.221=h6bb024c_0 + - freetype=2.10.4=h5ab3b9f_0 + - intel-openmp=2020.2=254 + - joblib=1.0.1=pyhd3eb1b0_0 + - jpeg=9b=h024ee3a_2 + - lcms2=2.11=h396b838_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - libedit=3.1.20191231=h14c3975_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.1.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libprotobuf=3.14.0=h8c45485_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.1.0=h2733197_1 + - libuv=1.40.0=h7b6447c_0 + - lz4-c=1.9.3=h2531618_0 + - mkl=2020.2=256 + - mkl-service=2.3.0=py38he904b0f_0 + - mkl_fft=1.2.1=py38h54f3939_0 + - mkl_random=1.1.1=py38h0573a6f_0 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=py38hff7bd54_0 + - numpy=1.19.2=py38h54aff64_0 + - numpy-base=1.19.2=py38hfa32c7d_0 + - olefile=0.46=py_0 + - openssl=1.1.1j=h27cfd23_0 + - pandas=1.2.2=py38ha9443f7_0 + - pillow=8.1.0=py38he98fc37_0 + - pip=21.0.1=py38h06a4308_0 + - protobuf=3.14.0=py38h2531618_1 + - python=3.8.5=h7579374_1 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - python_abi=3.8=1_cp38 + - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0 + - pytz=2021.1=pyhd3eb1b0_0 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py38h06a4308_0 + - six=1.15.0=py38h06a4308_0 + - sqlite=3.33.0=h62c20be_0 + - tensorboardx=2.1=py_0 + - tk=8.6.10=hbc83047_0 + - torchaudio=0.7.2=py38 + - torchvision=0.8.2=py38_cu110 + - tqdm=4.56.0=pyhd3eb1b0_0 + - typing_extensions=3.7.4.3=pyha847dfd_0 + - tzdata=2020f=h52ac0ba_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - pip: + - cycler==0.10.0 + - kiwisolver==1.3.1 + - matplotlib==3.3.4 + - opencv-contrib-python==4.5.1.48 + - opencv-python-headless==4.5.1.48 + - pyparsing==2.4.7 +prefix: /lcncluster/illing/.caas_HOME/miniconda3/envs/pdm + diff --git a/video/env_setup.txt b/video/env_setup.txt new file mode 100755 index 0000000..5f2e262 --- /dev/null +++ b/video/env_setup.txt @@ -0,0 +1,11 @@ +CUDA version 11.0 + +conda create -n pdm python=3 +conda activate pdm +conda install pytorch torchvision torchaudio cudatoolkit=11.0 -c pytorch +conda install joblib pandas tqdm +conda install -c conda-forge tensorboardx +pip install opencv-python-headless +pip install matplotlib + +conda env export > ./env.yml diff --git a/video/finetune.py b/video/finetune.py new file mode 100644 index 0000000..e60f41a --- /dev/null +++ b/video/finetune.py @@ -0,0 +1,372 @@ +# -*- coding: utf-8 -*- +from VGG_8 import VGG_8 +from augmentation import * +from dataset_3d_lc import * +from torch.utils import data +from tqdm import tqdm +from tensorboardX import SummaryWriter +import os +import time +import sys +import json +import torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +import argparse +from utils import AverageMeter + +parser = argparse.ArgumentParser() +parser.add_argument('--seq_len', default=8, type=int, help='number of frames in each sequence') +parser.add_argument('--temp_VGG', 
action='store_true', help='standard or temporal VGG-8') +parser.add_argument('--mode', default='CPC', help='Self-supervised algorithm, necessary for retrieving saved network structure') +parser.add_argument('--spatial_collapse', action='store_true', help='performing average pooling or not to obtain z') +parser.add_argument('--spatial_segm', action='store_true', help='use of spatial negatives (if not, then flattening)') +parser.add_argument('--single_predictor', action='store_true', help='use of a single recursively applied predictor') +parser.add_argument('--predictor_bias', action='store_true', help='linear predicting layer having bias or not') +parser.add_argument('--monitor_all_layers', action='store_true', help='perform the classification at each layer') +parser.add_argument('--batch_size', default=16, type=int) +parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') +parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') +parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') +parser.add_argument('--gpu', default='0', type=str) +parser.add_argument('--img_dim', default=128, type=int) +parser.add_argument('--name', help='relative path to load trained encoder and store the model and the tensorboard files') + + +# CLASS PERFORMING TOP-K ACCURACY FOR CLASSIFICATION +class top_k(nn.Module): + def __init__(self, k): + super(top_k, self).__init__() + k = [k] if isinstance(k, int) else k + self.k=k + + def forward(self, input, targets): + accs = [] + for k in self.k: + acc_k = torch.mean(torch.tensor([(target == input_line).any().float() for (target, input_line) in zip(targets,torch.topk(input, k, dim=1)[1])])) + accs.append(acc_k) + return accs + + +def classify(): + torch.autograd.set_detect_anomaly(True) + + global args; args = parser.parse_args() + os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) + global cuda; cuda = torch.device('cuda') + + img_path = args.name + # CREATING ENCODER MODEL, LAST ARGUMENT set as True PREVENTS COSTLY COMPUTATION OF SELF-SUPERVISED LOSSES AND ACCS + base_model = VGG_8(args.temp_VGG, args.mode, args.spatial_collapse, args.single_predictor, args.spatial_segm, args.predictor_bias, True) + # IF MODEL FOUND IN FOLDER DESIGNATED BY name, LOAD PARAMETERS + if os.path.isfile(img_path+'/model.pth.tar'): + base_model.load_state_dict(torch.load(img_path+'/model.pth.tar')) + else: + print('file not found, starts with random encoder') + + + # FAKE INPUT TO COMPUTE SIZE OF CLASSIFIERS (AT EACH LAYER OR JUST AT THE END) + input = torch.randn(1,3,args.seq_len,args.img_dim,args.img_dim) + output_sizes = [] + for block in base_model.blocks: + input = block(input.detach()) + if args.monitor_all_layers: + # WE DO NOT COUNT TIME DIMENSION (2) BECAUSE IT IS AVERAGE POOLED + output_sizes.append([int(torch.numel(input)/input.size(2)), input.size(2)]) + + if not args.monitor_all_layers: + output_sizes.append([int(torch.numel(input)/input.size(2)),input.size(2)]) + + + # CREATION OF THE CLASSIFIER(S) FOR EACH OUTPUT SIZE + classifications = nn.ModuleList() + for i, output_size in enumerate(output_sizes): + classifications.append(nn.Sequential(nn.AvgPool3d((output_size[1],1,1)),nn.Flatten(),nn.BatchNorm1d(output_size[0]), nn.Dropout(0.5), nn.Linear(output_size[0], 101))) + classifications[i][2].weight.data.fill_(1) + classifications[i][2].bias.data.zero_() + + for name, param in classifications[i][-1].named_parameters(): + if 'bias' in name: + nn.init.constant_(param, 0.0) + elif 'weight' in name: + 
nn.init.orthogonal_(param, 1) + + + # MIGRATING MODEL AND CLASSIFIERS TO CUDA + base_model = base_model.to(cuda) + classifications = classifications.to(cuda) + + + # SETTING THE FINETUNING, CLASSIFIER WITH lr AND ENCODER WITH lr/10 + print('=> finetune backbone with smaller lr') + params = [] + for name, param in base_model.named_parameters(): + params.append({'params': param, 'lr': args.lr/10}) + for name, param in classifications.named_parameters(): + params.append({'params': param}) + + + # CHECKING GRADIENTS OF DIFFERENT COMPONENTS + print('\n===========Check Grad============') + for name, param in base_model.named_parameters(): + print(name, param.requires_grad) + for name, param in classifications.named_parameters(): + print(name, param.requires_grad) + print('=================================\n') + + + # GIVE THE PARAMETERS TO THE OPTIMIZER + optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) + lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[60, 80, 100], repeat=1) + scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + + + # DUMMY VALIDATION LOSS FOR SAVING BEST MODEL + best_loss = 100 + global iteration; iteration = 0 + + + # DEFINE THE TRANSFORMATIONS FOR TRAIN AND VALIDATION + transform = transforms.Compose([ + RandomSizedCrop(consistent=True, size=224, p=1.0), + Scale(size=(args.img_dim,args.img_dim)), + RandomHorizontalFlip(consistent=True), + ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), + ToTensor(), + Normalize() + ]) + val_transform = transforms.Compose([ + RandomSizedCrop(consistent=True, size=224, p=0.3), + Scale(size=(args.img_dim,args.img_dim)), + RandomHorizontalFlip(consistent=True), + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), + ToTensor(), + Normalize() + ]) + + train_loader = get_data(transform, 'train') + val_loader = get_data(val_transform, 'val') + + appendix = '_finetune' + + + # INSTANTIATION OF THE TENSORBOARD MONITORING + try: # old version + writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'classification'+appendix+'/val')) + writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'classification'+appendix+'/train')) + except: # v1.7 + writer_val = SummaryWriter(logdir=os.path.join(img_path, 'classification'+appendix+'/val')) + writer_train = SummaryWriter(logdir=os.path.join(img_path, 'classification'+appendix+'/train')) + + + for epoch in range(args.epochs): + train_losses, train_accs = train(base_model, classifications, train_loader, optimizer, epoch, args.monitor_all_layers) + val_losses, val_accs = validate(base_model, classifications, val_loader, epoch, args.monitor_all_layers) + + scheduler.step() + + + # SAVE CURVES, ITERATE OVER LOSSES OF THE NETWORK (1 LOSS IF END-TO-END AND N IF PER-LAYER) + for i, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)): + writer_train.add_scalar('global/loss_{}'.format(i), train_loss, epoch) + writer_val.add_scalar('global/loss_{}'.format(i), val_loss, epoch) + + + # SAVE CURVES, ITERATE OVER ACCURACIES OF THE NETWORK ([3] ACCURACIES IF END-TO-END AND N*[3] IF PER-LAYER) + for i, (train_acc, val_acc) in enumerate(zip(train_accs, val_accs)): + for j in range(3): + a= [1,3,5] + writer_train.add_scalar('global/accuracy_{}_top_{}'.format(i,a[j]), train_acc[j], epoch) + writer_val.add_scalar('global/accuracy_{}_top_{}'.format(i, a[j]), val_acc[j], epoch) + + + # SAVE MODEL IF BEST VALIDATION LOSS + if val_losses[-1] <= best_loss: + 
best_loss = val_loss + torch.save(classifications.state_dict(), img_path+'/classifier'+appendix+'.pth.tar') + + print('epoch {}/{}'.format(epoch, args.epochs)) + + + +def get_data(transform, mode='train'): + print('Loading data for "%s" ...' % mode) + global dataset + dataset = UCF101_3d(mode=mode, + transform=transform, + seq_len=args.seq_len, + num_seq=1, # NUMBER OF SEQUENCES, ARTEFACT FROM DPC CODE, KEEP SET TO 1! + downsample=3) # FRAME RATE DOWNSAMPLING: FPS = 30/downsample + + my_sampler = data.RandomSampler(dataset) + if mode == 'train': + data_loader = data.DataLoader(dataset, + batch_size=args.batch_size, + sampler=my_sampler, + shuffle=False, + num_workers=16, + pin_memory=True, + drop_last=True) + elif mode == 'val': + data_loader = data.DataLoader(dataset, + batch_size=args.batch_size, + sampler=my_sampler, + shuffle=False, + num_workers=16, + pin_memory=True, + drop_last=True) + + print('"%s" dataset size: %d' % (mode, len(dataset))) + return data_loader + + + +def train(model, classifiers, data_loader, optimizer, epoch, monitor_all_layers): + cuda = torch.device('cuda') + # SET THE LOSSES AND ACCURACIES + # WARNING: USING x*[OBJECT] DUPLICATES REFERENCES TO THE SAME OBJECT INSTANCE + # HENCE THE FOR LOOP + losses = [] + accuracies = [] + Losses = [] + Accs = [] + if isinstance(classifiers, nn.ModuleList): + for i in range(len(classifiers)): + losses.append(AverageMeter()) + accuracies.append([AverageMeter(),AverageMeter(),AverageMeter()]) + Losses.append(nn.CrossEntropyLoss()) + Accs.append( top_k([1,3,5])) + + model.train() + classifiers.train() + + for (input_seq, target) in tqdm(data_loader): + # CREATING THE LIST OF NETWORK OUTPUTS AND CLASSIFICATION LOSSES + res_losses = [] + outputs = [] + + input_seq = input_seq.squeeze().to(cuda) + target = target.squeeze().to(cuda) + B = input_seq.size(0) + + + # IF ONLY CLASSIFICATION AT FINAL LAYER + if not monitor_all_layers: + _, _, output = model(input_seq) + outputs.append(output) + else: + output=input_seq + + for block in model.blocks: + # OTHERWISE AT EACH LAYER + output = block(output) + outputs.append(output) + + + # MEASURE THE CLASSIFICATION PERFORMANCE + for output, classifier, Loss, Acc, loss, accuracy in zip(outputs, classifiers, Losses, Accs, losses, accuracies): + # PASS THE OUTPUT(S) TO ITS/THEIR CLASSIFIER + output = classifier(output) + # COMPUTE THE CLASSIFIER'S LOSS AND ACCURACIES + l = Loss(output, target) + res_losses.append(l) + acc = Acc(output, target) + loss.update(l.item(), B) + for j in range(3): + accuracy[j].update(acc[j].item(), B) + + + # BACKWARD AND UPDATE THE LOSS RESULTING FROM THE LAST OUTPUT + optimizer.zero_grad() + res_losses[-1].backward() + optimizer.step() + + + # PRINT PERFORMANCES INDEXES AT EVERY EPOCH + for loss, acc in zip(losses, accuracies): + print('Training loss: {:.4f} | top1: {:.4f} | top3: {:.4f} | top5: {:.4f}'.format(loss.avg, acc[0].avg, acc[1].avg ,acc[2].avg)) + return [loss.local_avg for loss in losses], [[acc[0].avg,acc[1].avg,acc[2].avg] for acc in accuracies] + + + +def validate(model, classifiers, data_loader, epoch, monitor_all_layers): + cuda = torch.device('cuda') + # SET THE LOSSES AND ACCURACIES + # WARNING: USING x*[OBJECT] DUPLICATES REFERENCES TO THE SAME OBJECT INSTANCE + # HENCE THE FOR LOOP + losses = [] + accuracies = [] + Losses = [] + Accs = [] + if isinstance(classifiers, nn.ModuleList): + for i in range(len(classifiers)): + losses.append(AverageMeter()) + accuracies.append([AverageMeter(),AverageMeter(),AverageMeter()]) + 
Losses.append(nn.CrossEntropyLoss()) + Accs.append( top_k([1,3,5])) + + model.eval() + classifiers.eval() + + for (input_seq, target) in tqdm(data_loader): + # CREATING THE LIST OF NETWORK OUTPUTS AND CLASSIFICATION LOSSES + outputs = [] + input_seq = input_seq.squeeze().to(cuda) + target = target.squeeze().to(cuda) + B = input_seq.size(0) + + + # IF ONLY CLASSIFICATION AT FINAL LAYER + if not monitor_all_layers: + _, _, output = model(input_seq) + outputs.append(output) + else: + output=input_seq + + for block in model.blocks: + # OTHERWISE AT EACH LAYER + output = block(output) + outputs.append(output) + + + # MEASURE THE CLASSIFICATION PERFORMANCE + for output, classifier, Loss, Acc, loss, accuracy in zip(outputs, classifiers, Losses, Accs, losses, accuracies): + # PASS THE OUTPUT(S) TO ITS/THEIR CLASSIFIER + output = classifier(output.detach()) + # COMPUTE THE CLASSIFIER'S LOSS AND ACCURACIES + l = Loss(output, target) + acc = Acc(output, target) + + loss.update(l.item(), B) + for j in range(3): + accuracy[j].update(acc[j].item(), B) + + + # PRINT PERFORMANCES INDEXES AT EVERY EPOCH + for loss, acc in zip(losses, accuracies): + print('Validation loss: {:.4f} | top1: {:.4f} | top3: {:.4f} | top5: {:.4f}'.format(loss.avg, acc[0].avg, acc[1].avg ,acc[2].avg)) + return [loss.local_avg for loss in losses], [[acc[0].avg,acc[1].avg,acc[2].avg] for acc in accuracies] + + + +# USE OF THE SAME LEARNING RATE SCHEDULER AS DPC, SHOULD BE TAKEN AWAY FOR MORE STABLE RESULTS +def MultiStepLR_Restart_Multiplier(epoch, gamma=0.1, step=[10,15,20], repeat=3): + '''return the multipier for LambdaLR, + 0 <= ep < 10: gamma^0 + 10 <= ep < 15: gamma^1 + 15 <= ep < 20: gamma^2 + 20 <= ep < 30: gamma^0 ... repeat 3 cycles and then keep gamma^2''' + max_step = max(step) + effective_epoch = epoch % max_step + if epoch // max_step >= repeat: + exp = len(step) - 1 + else: + exp = len([i for i in step if effective_epoch>=i]) + return gamma ** exp + + +if __name__ == '__main__': + args = sys.argv + + classify() \ No newline at end of file diff --git a/video/process_data/data/ucf101/classInd.txt b/video/process_data/data/ucf101/classInd.txt new file mode 100644 index 0000000..f5db44c --- /dev/null +++ b/video/process_data/data/ucf101/classInd.txt @@ -0,0 +1,101 @@ +1 ApplyEyeMakeup +2 ApplyLipstick +3 Archery +4 BabyCrawling +5 BalanceBeam +6 BandMarching +7 BaseballPitch +8 Basketball +9 BasketballDunk +10 BenchPress +11 Biking +12 Billiards +13 BlowDryHair +14 BlowingCandles +15 BodyWeightSquats +16 Bowling +17 BoxingPunchingBag +18 BoxingSpeedBag +19 BreastStroke +20 BrushingTeeth +21 CleanAndJerk +22 CliffDiving +23 CricketBowling +24 CricketShot +25 CuttingInKitchen +26 Diving +27 Drumming +28 Fencing +29 FieldHockeyPenalty +30 FloorGymnastics +31 FrisbeeCatch +32 FrontCrawl +33 GolfSwing +34 Haircut +35 Hammering +36 HammerThrow +37 HandstandPushups +38 HandstandWalking +39 HeadMassage +40 HighJump +41 HorseRace +42 HorseRiding +43 HulaHoop +44 IceDancing +45 JavelinThrow +46 JugglingBalls +47 JumpingJack +48 JumpRope +49 Kayaking +50 Knitting +51 LongJump +52 Lunges +53 MilitaryParade +54 Mixing +55 MoppingFloor +56 Nunchucks +57 ParallelBars +58 PizzaTossing +59 PlayingCello +60 PlayingDaf +61 PlayingDhol +62 PlayingFlute +63 PlayingGuitar +64 PlayingPiano +65 PlayingSitar +66 PlayingTabla +67 PlayingViolin +68 PoleVault +69 PommelHorse +70 PullUps +71 Punch +72 PushUps +73 Rafting +74 RockClimbingIndoor +75 RopeClimbing +76 Rowing +77 SalsaSpin +78 ShavingBeard +79 Shotput +80 SkateBoarding +81 Skiing +82 
Skijet +83 SkyDiving +84 SoccerJuggling +85 SoccerPenalty +86 StillRings +87 SumoWrestling +88 Surfing +89 Swing +90 TableTennisShot +91 TaiChi +92 TennisSwing +93 ThrowDiscus +94 TrampolineJumping +95 Typing +96 UnevenBars +97 VolleyballSpiking +98 WalkingWithDog +99 WallPushups +100 WritingOnBoard +101 YoYo diff --git a/video/process_data/readme.md b/video/process_data/readme.md new file mode 100644 index 0000000..810b8d3 --- /dev/null +++ b/video/process_data/readme.md @@ -0,0 +1,51 @@ +## Process data + +This folder has some tools to process UCF101, HMDB51 and Kinetics400 datasets. + +### 1. Download + +Download the videos from source: +[UCF101 source](https://www.crcv.ucf.edu/data/UCF101.php), +[HMDB51 source](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads), +[Kinetics400 source](https://deepmind.com/research/publications/kinetics-human-action-video-dataset). + +Make sure datasets are stored as follows: + +* UCF101 +``` +{your_path}/UCF101/videos/{action class}/{video name}.avi +{your_path}/UCF101/splits_classification/trainlist{01/02/03}.txt +{your_path}/UCF101/splits_classification/testlist{01/02/03}}.txt +``` +with {your_path} being the same as the one leading to process_data + +* HMDB51 +``` +{your_path}/HMDB51/videos/{action class}/{video name}.avi +{your_path}/HMDB51/split/testTrainMulti_7030_splits/{action class}_test_split{1/2/3}.txt +``` + +* Kinetics400 +``` +{your_path}/Kinetics400/videos/train_split/{action class}/{video name}.mp4 +{your_path}/Kinetics400/videos/val_split/{action class}/{video name}.mp4 +``` +Also keep the downloaded csv files, make sure you have: +``` +{your_path}/Kinetics/kinetics_train/kinetics_train.csv +{your_path}/Kinetics/kinetics_val/kinetics_val.csv +{your_path}/Kinetics/kinetics_test/kinetics_test.csv +``` + +### 2. Extract frames + +From pdm_final folder: `python process_data/src/extract_frame.py`. Video frames will be extracted. + +### 3. Collect all paths into csv + +From pdm_final folder: `python process_data/src/write_csv.py`. csv files will be stored in `data/` directory. 
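+
+As a minimal, illustrative check (not part of the repo scripts), the csv written in step 3 can be inspected with pandas: each row simply pairs a frame-folder path with its number of extracted jpg frames, which the dataset loaders later use to sample sequences. The path below is an assumption for the UCF101 split 1 csv; the exact output location depends on the folder the script is launched from.
+
+```python
+import pandas as pd
+
+# write_csv.py writes no header row, so name the two columns explicitly
+split = pd.read_csv('process_data/data/ucf101/train_split01.csv', header=None,
+                    names=['frame_dir', 'n_frames'])
+print(split.head())
+# e.g. flag clips that may be too short once frame-rate downsampling is applied
+print('clips with fewer than 32 frames:', (split['n_frames'] < 32).sum())
+```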
+ + + + + diff --git a/video/process_data/src/extract_frame.py b/video/process_data/src/extract_frame.py new file mode 100644 index 0000000..f5db999 --- /dev/null +++ b/video/process_data/src/extract_frame.py @@ -0,0 +1,111 @@ +from joblib import delayed, Parallel +import os +import sys +import glob +from tqdm import tqdm +import cv2 +import matplotlib.pyplot as plt +plt.switch_backend('agg') + +def extract_video_opencv(v_path, f_root, dim=240): + '''v_path: single video path; + f_root: root to store frames''' + + v_class = v_path.split('/')[-2] + v_name = os.path.basename(v_path)[0:-4] + out_dir = os.path.join(f_root, v_class, v_name) + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + vidcap = cv2.VideoCapture(v_path) + nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float + height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float + if (width == 0) or (height==0): + print(v_path, 'not successfully loaded, drop ..'); return + new_dim = resize_dim(width, height, dim) + + success, image = vidcap.read() + count = 1 + while success: + image = cv2.resize(image, new_dim, interpolation = cv2.INTER_LINEAR) + cv2.imwrite(os.path.join(out_dir, 'image_%05d.jpg' % count), image, + [cv2.IMWRITE_JPEG_QUALITY, 80])# quality from 0-100, 95 is default, high is good + success, image = vidcap.read() + count += 1 + if nb_frames > count: + print('/'.join(out_dir.split('/')[-2::]), 'NOT extracted successfully: %df/%df' % (count, nb_frames)) + vidcap.release() + +def resize_dim(w, h, target): + '''resize (w, h), such that the smaller side is target, keep the aspect ratio''' + if w >= h: + return (int(target * w / h), int(target)) + else: + return (int(target), int(target * h / w)) + +def main_UCF101(v_root, f_root): + print('extracting UCF101 ... ') + print('extracting videos from %s' % v_root) + print('frame save to %s' % f_root) + + if not os.path.exists(f_root): os.makedirs(f_root) + v_act_root = glob.glob(os.path.join(v_root, '*/')) + for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): + v_paths = glob.glob(os.path.join(j, '*.avi')) + v_paths = sorted(v_paths) + Parallel(n_jobs=8)(delayed(extract_video_opencv)(p.replace('\\','/'), f_root) for p in tqdm(v_paths, total=len(v_paths))) + +def main_HMDB51(v_root, f_root): + print('extracting HMDB51 ... ') + print('extracting videos from %s' % v_root) + print('frame save to %s' % f_root) + + if not os.path.exists(f_root): os.makedirs(f_root) + v_act_root = glob.glob(os.path.join(v_root, '*/')) + for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): + v_paths = glob.glob(os.path.join(j, '*.avi')) + v_paths = sorted(v_paths) + Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) + +def main_kinetics400(v_root, f_root, dim=150): + print('extracting Kinetics400 ... 
') + for basename in ['train_split', 'val_split']: + v_root_real = v_root + '/' + basename + if not os.path.exists(v_root_real): + print('Wrong v_root'); sys.exit() + f_root_real = '/scratch/local/ssd/htd/kinetics400/frame_full' + '/' + basename + print('Extract to: \nframe: %s' % f_root_real) + if not os.path.exists(f_root_real): os.makedirs(f_root_real) + v_act_root = glob.glob(os.path.join(v_root_real, '*/')) + v_act_root = sorted(v_act_root) + + # if resume, remember to delete the last video folder + for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): + v_paths = glob.glob(os.path.join(j, '*.mp4')) + v_paths = sorted(v_paths) + # for resume: + v_class = j.split('/')[-2] + out_dir = os.path.join(f_root_real, v_class) + if os.path.exists(out_dir): print(out_dir, 'exists!'); continue + print('extracting: %s' % v_class) + # dim = 150 (crop to 128 later) or 256 (crop to 224 later) + Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root_real, dim=dim) for p in tqdm(v_paths, total=len(v_paths))) + + +if __name__ == '__main__': + # v_root is the video source path, f_root is where to store frames + # edit 'your_path' here: + + main_UCF101(v_root=r'UCF101/videos', + f_root=r'UCF101/frame') + + # main_HMDB51(v_root='your_path/HMDB51/videos', + # f_root='your_path/HMDB51/frame') + + # main_kinetics400(v_root='your_path/Kinetics400/videos', + # f_root='your_path/Kinetics400/frame', dim=150) + + # main_kinetics400(v_root='your_path/Kinetics400_256/videos', + # f_root='your_path/Kinetics400_256/frame', dim=256) diff --git a/video/process_data/src/write_csv.py b/video/process_data/src/write_csv.py new file mode 100644 index 0000000..41b5d69 --- /dev/null +++ b/video/process_data/src/write_csv.py @@ -0,0 +1,115 @@ +import os +import csv +import glob +import pandas as pd +from joblib import delayed, Parallel +from tqdm import tqdm + + +def write_list(data_list, path, ): + with open(path, 'w') as f: + writer = csv.writer(f, delimiter=',') + for row in data_list: + if row: writer.writerow(row) + print('split saved to %s' % path) + +def main_UCF101(f_root, splits_root, csv_root='../data/ucf101/'): + '''generate training/testing split, count number of available frames, save in csv''' + if not os.path.exists(csv_root): os.makedirs(csv_root) + for which_split in [1,2,3]: + train_set = [] + test_set = [] + train_split_file = os.path.join(splits_root, 'trainlist%02d.txt' % which_split) + with open(train_split_file, 'r') as f: + for line in f: + vpath = os.path.join(f_root, line.split(' ')[0][0:-4]) + '/' + train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) + + test_split_file = os.path.join(splits_root, 'testlist%02d.txt' % which_split) + with open(test_split_file, 'r') as f: + for line in f: + vpath = os.path.join(f_root, line.rstrip()[0:-4]) + '/' + test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) + + write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) + write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) + + +def main_HMDB51(f_root, splits_root, csv_root='../data/hmdb51/'): + '''generate training/testing split, count number of available frames, save in csv''' + if not os.path.exists(csv_root): os.makedirs(csv_root) + for which_split in [1,2,3]: + train_set = [] + test_set = [] + split_files = sorted(glob.glob(os.path.join(splits_root, '*_test_split%d.txt' % which_split))) + assert len(split_files) == 51 + for split_file in split_files: + action_name = os.path.basename(split_file)[0:-16] + 
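+            # '_test_split{1/2/3}.txt' is exactly 16 characters, so the slice above leaves just the HMDB51 action class name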
with open(split_file, 'r') as f: + for line in f: + video_name = line.split(' ')[0] + _type = line.split(' ')[1] + vpath = os.path.join(f_root, action_name, video_name[0:-4]) + '/' + if _type == '1': + train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) + elif _type == '2': + test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) + + write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) + write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) + +### For Kinetics ### +def get_split(root, split_path, mode): + print('processing %s split ...' % mode) + print('checking %s' % root) + split_list = [] + split_content = pd.read_csv(split_path).iloc[:,0:4] + split_list = Parallel(n_jobs=64)\ + (delayed(check_exists)(row, root) \ + for i, row in tqdm(split_content.iterrows(), total=len(split_content))) + return split_list + +def check_exists(row, root): + dirname = '_'.join([row['youtube_id'], '%06d' % row['time_start'], '%06d' % row['time_end']]) + full_dirname = os.path.join(root, row['label'], dirname) + if os.path.exists(full_dirname): + n_frames = len(glob.glob(os.path.join(full_dirname, '*.jpg'))) + return [full_dirname, n_frames] + else: + return None + +def main_Kinetics400(mode, k400_path, f_root, csv_root='../data/kinetics400'): + train_split_path = os.path.join(k400_path, 'kinetics_train/kinetics_train.csv') + val_split_path = os.path.join(k400_path, 'kinetics_val/kinetics_val.csv') + test_split_path = os.path.join(k400_path, 'kinetics_test/kinetics_test.csv') + if not os.path.exists(csv_root): os.makedirs(csv_root) + if mode == 'train': + train_split = get_split(os.path.join(f_root, 'train_split'), train_split_path, 'train') + write_list(train_split, os.path.join(csv_root, 'train_split.csv')) + elif mode == 'val': + val_split = get_split(os.path.join(f_root, 'val_split'), val_split_path, 'val') + write_list(val_split, os.path.join(csv_root, 'val_split.csv')) + elif mode == 'test': + test_split = get_split(f_root, test_split_path, 'test') + write_list(test_split, os.path.join(csv_root, 'test_split.csv')) + else: + raise IOError('wrong mode') + +if __name__ == '__main__': + # f_root is the frame path + # edit 'your_path' here: + + main_UCF101(f_root=r'UCF101/frame/', + splits_root=r'UCF101/splits_classification') + + # main_HMDB51(f_root='your_path/HMDB51/frame', + # splits_root='your_path/HMDB51/split/testTrainMulti_7030_splits') + + # main_Kinetics400(mode='train', # train or val or test + # k400_path='your_path/Kinetics', + # f_root='your_path/Kinetics400/frame') + + # main_Kinetics400(mode='train', # train or val or test + # k400_path='your_path/Kinetics', + # f_root='your_path/Kinetics400_256/frame', + # csv_root='../data/kinetics400_256') diff --git a/video/requirements.yml b/video/requirements.yml new file mode 100755 index 0000000..65362e3 --- /dev/null +++ b/video/requirements.yml @@ -0,0 +1,215 @@ +name: pdm +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - alabaster=0.7.12=py_0 + - argh=0.26.2=py38_0 + - argon2-cffi=20.1.0=py38he774522_1 + - astroid=2.4.2=py38_0 + - async_generator=1.10=py_0 + - atomicwrites=1.4.0=py_0 + - attrs=20.3.0=pyhd3eb1b0_0 + - autopep8=1.5.4=py_0 + - babel=2.9.0=pyhd3eb1b0_0 + - backcall=0.2.0=py_0 + - bcrypt=3.2.0=py38he774522_0 + - blas=1.0=mkl + - bleach=3.2.1=py_0 + - brotlipy=0.7.0=py38h2bbff1b_1003 + - ca-certificates=2020.12.5=h5b45459_0 + - certifi=2020.12.5=py38haa244fe_0 + - cffi=1.14.4=py38hcd4344a_0 + - 
chardet=3.0.4=py38haa95532_1003 + - cloudpickle=1.6.0=py_0 + - colorama=0.4.4=py_0 + - cryptography=3.2.1=py38hcd4344a_1 + - cudatoolkit=11.0.221=h74a9793_0 + - cycler=0.10.0=py38_0 + - decorator=4.4.2=py_0 + - defusedxml=0.6.0=py_0 + - diff-match-patch=20200713=py_0 + - docutils=0.16=py38_1 + - entrypoints=0.3=py38_0 + - flake8=3.8.4=py_0 + - freetype=2.10.4=hd328e21_0 + - future=0.18.2=py38_1 + - hdf5=1.10.4=h7ebc959_0 + - icc_rt=2019.0.0=h0cc432a_1 + - icu=58.2=ha925a31_3 + - idna=2.10=py_0 + - imagesize=1.2.0=py_0 + - importlib-metadata=2.0.0=py_1 + - importlib_metadata=2.0.0=1 + - intel-openmp=2020.2=254 + - intervaltree=3.1.0=py_0 + - ipdb=0.13.4=pyhd3deb0d_0 + - ipykernel=5.3.4=py38h5ca1d4c_0 + - ipython=7.19.0=py38hd4e2768_0 + - ipython_genutils=0.2.0=pyhd3eb1b0_1 + - ipywidgets=7.5.1=py_1 + - isort=5.6.4=py_0 + - jedi=0.17.1=py38_0 + - jinja2=2.11.2=py_0 + - joblib=0.17.0=py_0 + - jpeg=9b=hb83a4c4_2 + - jsonschema=3.2.0=py_2 + - jupyter=1.0.0=py38_7 + - jupyter_client=6.1.7=py_0 + - jupyter_console=6.2.0=py_0 + - jupyter_core=4.7.0=py38haa95532_0 + - jupyterlab_pygments=0.1.2=py_0 + - keyring=21.4.0=py38_1 + - kiwisolver=1.3.0=py38hd77b12b_0 + - lazy-object-proxy=1.4.3=py38he774522_0 + - libopencv=4.0.1=hbb9e17c_0 + - libpng=1.6.37=h2a8f88b_0 + - libprotobuf=3.14.0=h7755175_0 + - libsodium=1.0.18=h62dcd97_0 + - libspatialindex=1.9.3=h33f27b4_0 + - libtiff=4.1.0=h56a325e_1 + - libuv=1.40.0=he774522_0 + - lz4-c=1.9.2=hf4a77e7_3 + - m2w64-gcc-libgfortran=5.3.0=6 + - m2w64-gcc-libs=5.3.0=7 + - m2w64-gcc-libs-core=5.3.0=7 + - m2w64-gmp=6.1.0=2 + - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 + - markupsafe=1.1.1=py38he774522_0 + - matplotlib=3.3.2=0 + - matplotlib-base=3.3.2=py38hba9282a_0 + - mccabe=0.6.1=py38_1 + - mistune=0.8.4=py38he774522_1000 + - mkl=2020.2=256 + - mkl-service=2.3.0=py38h196d8e1_0 + - mkl_fft=1.2.0=py38h45dec08_0 + - mkl_random=1.1.1=py38h47e9c7a_0 + - msys2-conda-epoch=20160418=1 + - nbclient=0.5.1=py_0 + - nbconvert=6.0.7=py38_0 + - nbformat=5.0.8=py_0 + - nest-asyncio=1.4.3=pyhd3eb1b0_0 + - ninja=1.10.2=py38h6d14046_0 + - notebook=6.1.4=py38_0 + - numpy=1.19.2=py38hadc3359_0 + - numpy-base=1.19.2=py38ha3acd2a_0 + - numpydoc=1.1.0=pyhd3eb1b0_1 + - olefile=0.46=py_0 + - opencv=4.0.1=py38h2a7c758_0 + - openssl=1.1.1i=h8ffe710_0 + - packaging=20.7=pyhd3eb1b0_0 + - pandas=1.1.3=py38ha925a31_0 + - pandoc=2.11=h9490d1a_0 + - pandocfilters=1.4.3=py38haa95532_1 + - paramiko=2.7.2=py_0 + - parso=0.7.0=py_0 + - pathtools=0.1.2=py_1 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pillow=8.0.1=py38h4fa10fc_0 + - pip=20.3.1=py38haa95532_0 + - pluggy=0.13.1=py38_0 + - prometheus_client=0.9.0=pyhd3eb1b0_0 + - prompt-toolkit=3.0.8=py_0 + - prompt_toolkit=3.0.8=0 + - psutil=5.7.2=py38he774522_0 + - ptyprocess=0.6.0=pyhd3eb1b0_2 + - py-opencv=4.0.1=py38he44ac1e_0 + - pycodestyle=2.6.0=py_0 + - pycparser=2.20=py_2 + - pydocstyle=5.1.1=py_0 + - pyflakes=2.2.0=py_0 + - pygments=2.7.3=pyhd3eb1b0_0 + - pylint=2.6.0=py38_0 + - pynacl=1.4.0=py38h62dcd97_1 + - pyopenssl=20.0.0=pyhd3eb1b0_1 + - pyparsing=2.4.7=py_0 + - pyqt=5.9.2=py38ha925a31_4 + - pyrsistent=0.17.3=py38he774522_0 + - pysocks=1.7.1=py38haa95532_0 + - python=3.8.5=h5fd99cc_1 + - python-dateutil=2.8.1=py_0 + - python-jsonrpc-server=0.4.0=py_0 + - python-language-server=0.35.1=py_0 + - python_abi=3.8=1_cp38 + - pytorch=1.7.0=py3.8_cuda110_cudnn8_0 + - pytz=2020.4=pyhd3eb1b0_0 + - pywin32=227=py38he774522_1 + - pywin32-ctypes=0.2.0=py38_1000 + - pywinpty=0.5.7=py38_0 + - pyyaml=5.3.1=py38he774522_1 + 
- pyzmq=20.0.0=py38hd77b12b_1 + - qdarkstyle=2.8.1=py_0 + - qt=5.9.7=vc14h73c81de_0 + - qtawesome=1.0.1=py_0 + - qtconsole=4.7.7=py_0 + - qtpy=1.9.0=py_0 + - requests=2.25.0=pyhd3eb1b0_0 + - rope=0.18.0=py_0 + - rtree=0.9.4=py38h21ff451_1 + - send2trash=1.5.0=py38_0 + - setuptools=51.0.0=py38haa95532_2 + - sip=4.19.13=py38ha925a31_0 + - six=1.15.0=py38haa95532_0 + - snowballstemmer=2.0.0=py_0 + - sortedcontainers=2.3.0=pyhd3eb1b0_0 + - sphinx=3.2.1=py_0 + - sphinxcontrib-applehelp=1.0.2=py_0 + - sphinxcontrib-devhelp=1.0.2=py_0 + - sphinxcontrib-htmlhelp=1.0.3=py_0 + - sphinxcontrib-jsmath=1.0.1=py_0 + - sphinxcontrib-qthelp=1.0.3=py_0 + - sphinxcontrib-serializinghtml=1.1.4=py_0 + - spyder=4.1.5=py38_0 + - spyder-kernels=1.9.4=py38_0 + - sqlite=3.33.0=h2a8f88b_0 + - tensorboardx=2.1=py_0 + - terminado=0.9.1=py38_0 + - testpath=0.4.4=py_0 + - tk=8.6.10=he774522_0 + - toml=0.10.1=py_0 + - torchaudio=0.7.0=py38 + - torchvision=0.8.1=py38_cu110 + - tornado=6.1=py38h2bbff1b_0 + - tqdm=4.54.1=pyhd3eb1b0_0 + - traitlets=5.0.5=py_0 + - typing_extensions=3.7.4.3=py_0 + - ujson=4.0.1=py38ha925a31_0 + - urllib3=1.25.11=py_0 + - vc=14.2=h21ff451_1 + - vs2015_runtime=14.27.29016=h5e58377_2 + - watchdog=0.10.4=py38haa95532_0 + - wcwidth=0.2.5=py_0 + - webencodings=0.5.1=py38_1 + - wheel=0.36.1=pyhd3eb1b0_0 + - widgetsnbextension=3.5.1=py38_0 + - win_inet_pton=1.1.0=py38haa95532_0 + - wincertstore=0.2=py38_0 + - winpty=0.4.3=4 + - wrapt=1.11.2=py38he774522_0 + - xz=5.2.5=h62dcd97_0 + - yaml=0.2.5=he774522_0 + - yapf=0.30.0=py_0 + - zeromq=4.3.3=ha925a31_3 + - zipp=3.4.0=pyhd3eb1b0_0 + - zlib=1.2.11=h62dcd97_4 + - zstd=1.4.5=h04227a9_0 + - pip: + - absl-py==0.11.0 + - cachetools==4.2.0 + - google-auth==1.24.0 + - google-auth-oauthlib==0.4.2 + - grpcio==1.34.0 + - markdown==3.3.3 + - oauthlib==3.1.0 + - protobuf==3.14.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - requests-oauthlib==1.3.0 + - rsa==4.6 + - tensorboard==2.4.0 + - tensorboard-plugin-wit==1.7.0 + - torchsummary==1.5.1 + - werkzeug==1.0.1 +prefix: C:\Users\Jean\anaconda3\envs\pdm diff --git a/video/run.py b/video/run.py new file mode 100644 index 0000000..d5c6be0 --- /dev/null +++ b/video/run.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + + +from VGG_8 import VGG_8 +from augmentation import * +from train import train, validate +from dataset_3d import * +from torch.utils import data + +from tensorboardX import SummaryWriter +import os +import time +import sys +import json +import torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +import argparse + + +parser = argparse.ArgumentParser() +parser.add_argument('--seq_len', default=8, type=int, help='number of frames in each sequence') +parser.add_argument('--temp_VGG', action='store_true', help='standard or temporal VGG-8') +parser.add_argument('--mode', default='CPC', help='Self-supervised algorithm') +parser.add_argument('--spatial_collapse', action='store_true', help='performing average pooling or not to obtain z') +parser.add_argument('--spatial_segm', action='store_true', help='use of spatial negatives (if not, then flattening)') +parser.add_argument('--single_predictor', action='store_true', help='use of a single recursively applied predictor') +parser.add_argument('--predictor_bias', action='store_true', help='linear predicting layer having bias or not') +parser.add_argument('--batch_size', default=16, type=int) +parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') +parser.add_argument('--wd', default=1e-5, type=float, help='weight 
decay') +parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') +parser.add_argument('--gpu', default='0', type=str) +parser.add_argument('--img_dim', default=128, type=int) +parser.add_argument('--name', help='relative path to store the model and the tensorboard files') + +def run(): + torch.autograd.set_detect_anomaly(True) + + global args; args = parser.parse_args() + os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) + global cuda; cuda = torch.device('cuda') + + img_path = args.name + + # CREATION OF THE NETWORK AND ITS LOSS + model = VGG_8(args.temp_VGG, args.mode, args.spatial_collapse, args.single_predictor, args.spatial_segm, args.predictor_bias) + + model = model.to(cuda) + + + print('\n===========Check Grad============') + for name, param in model.named_parameters(): + print(name, param.requires_grad) + print('=================================\n') + + params = model.parameters() + optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) + + + # DUMMY VALIDATION LOSS FOR SAVING BEST MODEL + best_loss = 100 + global iteration; iteration = 0 + + + # TRANSFORMATION FOR SELF-SUPERVISED TRAINING + # ARGUMENT consistent SETS IF PROCESSING PER FRAME OR PER SEQUENCE + transform = transforms.Compose([ + RandomHorizontalFlip(consistent=True), + RandomCrop(size=224, consistent=True), + Scale(size=(args.img_dim,args.img_dim)), + RandomGray(consistent=False, p=0.5), + ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0, consistent=False), + ToTensor(), + Normalize() + ]) + + + # CREATION OF DATALOADERS + train_loader = get_data(transform, 'train') + val_loader = get_data(transform, 'val') + + + # INSTANTIATION OF THE TENSORBOARD MONITORING + try: # old version + writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val')) + writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train')) + except: # v1.7 + writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val')) + writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train')) + + + for epoch in range(args.epochs): + + train_losses, train_accs = train(model, train_loader, optimizer, epoch) + val_losses, val_accs = validate(model, val_loader, epoch) + + + # SAVE CURVES, ITERATE OVER LOSSES OF THE NETWORK (1 LOSS IF END-TO-END AND N IF PER-LAYER) + for i, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)): + writer_train.add_scalar('global/loss_{}'.format(i), train_loss, epoch) + writer_val.add_scalar('global/loss_{}'.format(i), val_loss, epoch) + + + # SAVE CURVES, ITERATE OVER ACCURACIES OF THE NETWORK ([3] ACCURACIES IF END-TO-END AND N*[3] IF PER-LAYER) + for i, (train_acc, val_acc) in enumerate(zip(train_accs, val_accs)): + # EACH LOSS IS ASSOCIATED WITH TOP-1,3,5 ACCURACIES + for j in range(3): + a= [1,3,5] + writer_train.add_scalar('global/accuracy_{}_top_{}'.format(i,a[j]), train_acc[j], epoch) + writer_val.add_scalar('global/accuracy_{}_top_{}'.format(i, a[j]), val_acc[j], epoch) + + + # SAVE MODEL IF BEST VALIDATION LOSS + if val_losses[-1] <= best_loss: + best_loss = val_losses[-1] + torch.save(model.state_dict(), img_path+'/model.pth.tar') + + print('epoch {}/{}'.format(epoch, args.epochs)) + + + +def get_data(transform, mode='train'): + print('Loading data for "%s" ...' % mode) + + dataset = UCF101_3d(mode=mode, + transform=transform, + seq_len=args.seq_len, + num_seq=1, # NUMBER OF SEQUENCES, ARTEFACT FROM DPC CODE, KEEP SET TO 1! 
+ downsample=3) # FRAME RATE DOWNSAMPLING: FPS = 30/downsample + sampler = data.RandomSampler(dataset) + + data_loader = data.DataLoader(dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=8, + pin_memory=True, + drop_last=True) + + print('"%s" dataset size: %d' % (mode, len(dataset))) + + return data_loader + +if __name__ == '__main__': + args = sys.argv + + run() \ No newline at end of file diff --git a/video/test.py b/video/test.py new file mode 100644 index 0000000..c72b4c3 --- /dev/null +++ b/video/test.py @@ -0,0 +1,366 @@ +# -*- coding: utf-8 -*- + +from VGG_8 import VGG_8 +from augmentation import * +from dataset_3d_lc import * +from torch.utils import data +from tqdm import tqdm +from tensorboardX import SummaryWriter +import os +import sys +import torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +import argparse +from utils import AverageMeter + + +parser = argparse.ArgumentParser() +parser.add_argument('--seq_len', default=8, type=int, help='number of frames in each sequence') +parser.add_argument('--temp_VGG', action='store_true', help='standard or temporal VGG-8') +parser.add_argument('--mode', default='CPC', help='Self-supervised algorithm, necessary for retrieving saved network structure') +parser.add_argument('--spatial_collapse', action='store_true', help='performing average pooling or not to obtain z') +parser.add_argument('--spatial_segm', action='store_true', help='use of spatial negatives (if not, then flattening)') +parser.add_argument('--single_predictor', action='store_true', help='use of a single recursively applied predictor') +parser.add_argument('--predictor_bias', action='store_true', help='linear predicting layer having bias or not') +parser.add_argument('--monitor_all_layers', action='store_true', help='perform the classification at each layer') +parser.add_argument('--batch_size', default=16, type=int) +parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') +parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') +parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') +parser.add_argument('--gpu', default='0', type=str) +parser.add_argument('--img_dim', default=128, type=int) +parser.add_argument('--name', help='relative path to load trained encoder and store the model and the tensorboard files') + + +# CLASS PERFORMING TOP-K ACCURACY FOR CLASSIFICATION +class top_k(nn.Module): + def __init__(self, k): + super(top_k, self).__init__() + k = [k] if isinstance(k, int) else k + self.k=k + + def forward(self, input, targets): + accs = [] + for k in self.k: + acc_k = torch.mean(torch.tensor([(target == input_line).any().float() for (target, input_line) in zip(targets,torch.topk(input, k, dim=1)[1])])) + accs.append(acc_k) + return accs + + +def classify(): + torch.autograd.set_detect_anomaly(True) + + global args; args = parser.parse_args() + os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) + global cuda; cuda = torch.device('cuda') + + img_path = args.name + # CREATING ENCODER MODEL, LAST ARGUMENT set as True PREVENTS COSTLY COMPUTATION OF SELF-SUPERVISED LOSSES AND ACCS + base_model = VGG_8(args.temp_VGG, args.mode, args.spatial_collapse, args.single_predictor, args.spatial_segm, args.predictor_bias, True) + # IF MODEL FOUND IN FOLDER DESIGNATED BY name, LOAD PARAMETERS + if os.path.isfile(img_path+'/model.pth.tar'): + base_model.load_state_dict(torch.load(img_path+'/model.pth.tar')) + else: + print('file not found, 
starts with random encoder') + + + # FREEZE PARAMETERS + for param in base_model.parameters(): + param.requires_grad = False + base_model.eval() # ADDITIONAL PRECAUTION + FREEZES THE MEAN AND VAR FROM BATCHNORMS + + + # FAKE INPUT TO COMPUTE SIZE OF CLASSIFIERS (AT EACH LAYER OR JUST AT THE END) + input = torch.randn(1,3,args.seq_len,args.img_dim,args.img_dim) + output_sizes = [] + for block in base_model.blocks: + input = block(input.detach()) + if args.monitor_all_layers: + # WE DO NOT COUNT TIME DIMENSION (2) BECAUSE IT IS AVERAGE POOLED + output_sizes.append([int(torch.numel(input)/input.size(2)), input.size(2)]) + + if not args.monitor_all_layers: + output_sizes.append([int(torch.numel(input)/input.size(2)),input.size(2)]) + + + # CREATION OF THE CLASSIFIER(S) FOR EACH OUTPUT SIZE + classifications = nn.ModuleList() + for i, output_size in enumerate(output_sizes): + classifications.append(nn.Sequential(nn.AvgPool3d((output_size[1],1,1)),nn.Flatten(),nn.BatchNorm1d(output_size[0]), nn.Dropout(0.5), nn.Linear(output_size[0], 101))) + classifications[i][2].weight.data.fill_(1) + classifications[i][2].bias.data.zero_() + + for name, param in classifications[i][-1].named_parameters(): + if 'bias' in name: + nn.init.constant_(param, 0.0) + elif 'weight' in name: + nn.init.orthogonal_(param, 1) + + + # MIGRATING MODEL AND CLASSIFIERS TO CUDA + base_model = base_model.to(cuda) + classifications = classifications.to(cuda) + + + # CHECKING GRADIENTS OF DIFFERENT COMPONENTS + print('\n===========Check Grad============') + for name, param in base_model.named_parameters(): + print(name, param.requires_grad) + for name, param in classifications.named_parameters(): + print(name, param.requires_grad) + print('=================================\n') + + + # GIVE THE CLASSIFIERS' PARAMETERS TO THE OPTIMIZER + params = classifications.parameters() + optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) + lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[60, 80, 100], repeat=1) + scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + + + # DUMMY VALIDATION LOSS FOR SAVING BEST MODEL + best_loss = 100 + global iteration; iteration = 0 + + + # DEFINE THE TRANSFORMATIONS FOR TRAIN AND VALIDATION + transform = transforms.Compose([ + RandomSizedCrop(consistent=True, size=224, p=1.0), + Scale(size=(args.img_dim,args.img_dim)), + RandomHorizontalFlip(consistent=True), + ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), + ToTensor(), + Normalize() + ]) + val_transform = transforms.Compose([ + RandomSizedCrop(consistent=True, size=224, p=0.3), + Scale(size=(args.img_dim,args.img_dim)), + RandomHorizontalFlip(consistent=True), + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), + ToTensor(), + Normalize() + ]) + + train_loader = get_data(transform, 'train') + val_loader = get_data(val_transform, 'val') + + + # SPECIFY IF ALL LAYERS MONITORED + if args.monitor_all_layers: + appendix = '_all_layers' + else: + appendix = '' + + + # INSTANTIATION OF THE TENSORBOARD MONITORING + try: # old version + writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'classification'+appendix+'/val')) + writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'classification'+appendix+'/train')) + except: # v1.7 + writer_val = SummaryWriter(logdir=os.path.join(img_path, 'classification'+appendix+'/val')) + writer_train = SummaryWriter(logdir=os.path.join(img_path, 
'classification'+appendix+'/train')) + + + for epoch in range(args.epochs): + + train_losses, train_accs = train(base_model, classifications, train_loader, optimizer, epoch, args.monitor_all_layers) + val_losses, val_accs = validate(base_model, classifications, val_loader, epoch, args.monitor_all_layers) + + scheduler.step() + + # SAVE CURVES, ITERATE OVER LOSSES OF THE NETWORK (1 LOSS IF END-TO-END AND N IF PER-LAYER) + for i, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)): + writer_train.add_scalar('global/loss_{}'.format(i), train_loss, epoch) + writer_val.add_scalar('global/loss_{}'.format(i), val_loss, epoch) + + # SAVE CURVES, ITERATE OVER ACCURACIES OF THE NETWORK ([3] ACCURACIES IF END-TO-END AND N*[3] IF PER-LAYER) + for i, (train_acc, val_acc) in enumerate(zip(train_accs, val_accs)): + for j in range(3): + a= [1,3,5] + writer_train.add_scalar('global/accuracy_{}_top_{}'.format(i,a[j]), train_acc[j], epoch) + writer_val.add_scalar('global/accuracy_{}_top_{}'.format(i, a[j]), val_acc[j], epoch) + + # SAVE MODEL IF BEST VALIDATION LOSS + if val_losses[-1] <= best_loss: + best_loss = val_loss + torch.save(classifications.state_dict(), img_path+'/classifier'+appendix+'.pth.tar') + + print('epoch {}/{}'.format(epoch, args.epochs)) + + + +def get_data(transform, mode='train'): + print('Loading data for "%s" ...' % mode) + global dataset + dataset = UCF101_3d(mode=mode, + transform=transform, + seq_len=args.seq_len, + num_seq=1, # NUMBER OF SEQUENCES, ARTEFACT FROM DPC CODE, KEEP SET TO 1! + downsample=3) # FRAME RATE DOWNSAMPLING: FPS = 30/downsample + + my_sampler = data.RandomSampler(dataset) + if mode == 'train': + data_loader = data.DataLoader(dataset, + batch_size=args.batch_size, + sampler=my_sampler, + shuffle=False, + num_workers=16, + pin_memory=True, + drop_last=True) + elif mode == 'val': + data_loader = data.DataLoader(dataset, + batch_size=args.batch_size, + sampler=my_sampler, + shuffle=False, + num_workers=16, + pin_memory=True, + drop_last=True) + + print('"%s" dataset size: %d' % (mode, len(dataset))) + return data_loader + + + +def train(model, classifiers, data_loader, optimizer, epoch, monitor_all_layers): + cuda = torch.device('cuda') + # SET THE LOSSES AND ACCURACIES + # WARNING: USING x*[OBJECT] DUPLICATES REFERENCES TO THE SAME OBJECT INSTANCE + # HENCE THE FOR LOOP + losses = [] + accuracies = [] + Losses = [] + Accs = [] + if isinstance(classifiers, nn.ModuleList): + for i in range(len(classifiers)): + losses.append(AverageMeter()) + accuracies.append([AverageMeter(),AverageMeter(),AverageMeter()]) + Losses.append(nn.CrossEntropyLoss()) + Accs.append( top_k([1,3,5])) + + model.eval() + classifiers.train() + + for (input_seq, target) in tqdm(data_loader): + # CREATING THE LIST OF NETWORK OUTPUTS AND CLASSIFICATION LOSSES + res_losses = [] + outputs = [] + + input_seq = input_seq.squeeze().to(cuda) + target = target.squeeze().to(cuda) + B = input_seq.size(0) + + # IF ONLY CLASSIFICATION AT FINAL LAYER + if not monitor_all_layers: + _, _, output = model(input_seq) + outputs.append(output) + else: + output=input_seq + # OTHERWISE AT EACH LAYER + for block in model.blocks: + output = block(output) + outputs.append(output) + + + # MEASURE THE CLASSIFICATION PERFORMANCE + for output, classifier, Loss, Acc, loss, accuracy in zip(outputs, classifiers, Losses, Accs, losses, accuracies): + # PASS THE OUTPUT(S) TO ITS/THEIR CLASSIFIER + output = classifier(output.detach()) + # COMPUTE THE CLASSIFIER'S LOSS AND ACCURACIES + l = Loss(output, target) + 
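+                # each loss is backpropagated separately below; since the encoder output is detached, gradients only reach the corresponding classifier head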
res_losses.append(l) + acc = Acc(output, target) + loss.update(l.item(), B) + for j in range(3): + accuracy[j].update(acc[j].item(), B) + + # ITERATE OVER LOSS OF EVERY CLASSIFIER, BACKWARD AND UPDATE + optimizer.zero_grad() + for l in res_losses: + l.backward() + optimizer.step() + + # PRINT PERFORMANCES INDEXES AT EVERY EPOCH + for loss, acc in zip(losses, accuracies): + print('Training loss: {:.4f} | top1: {:.4f} | top3: {:.4f} | top5: {:.4f}'.format(loss.avg, acc[0].avg, acc[1].avg ,acc[2].avg)) + return [loss.local_avg for loss in losses], [[acc[0].avg,acc[1].avg,acc[2].avg] for acc in accuracies] + + + +def validate(model, classifiers, data_loader, epoch, monitor_all_layers): + cuda = torch.device('cuda') + # SET THE LOSSES AND ACCURACIES + # WARNING: USING x*[OBJECT] DUPLICATES REFERENCES TO THE SAME OBJECT INSTANCE + # HENCE THE FOR LOOP + losses = [] + accuracies = [] + Losses = [] + Accs = [] + if isinstance(classifiers, nn.ModuleList): + for i in range(len(classifiers)): + losses.append(AverageMeter()) + accuracies.append([AverageMeter(),AverageMeter(),AverageMeter()]) + Losses.append(nn.CrossEntropyLoss()) + Accs.append( top_k([1,3,5])) + + model.eval() + classifiers.eval() + for (input_seq, target) in tqdm(data_loader): + # CREATING THE LIST OF NETWORK OUTPUTS + outputs = [] + input_seq = input_seq.squeeze().to(cuda) + target = target.squeeze().to(cuda) + B = input_seq.size(0) + + + # IF ONLY CLASSIFICATION AT FINAL LAYER + if not monitor_all_layers: + _, _, output = model(input_seq) + outputs.append(output) + else: + output=input_seq + # OTHERWISE AT EACH LAYER + for block in model.blocks: + output = block(output) + outputs.append(output) + + + # MEASURE THE CLASSIFICATION PERFORMANCE + for output, classifier, Loss, Acc, loss, accuracy in zip(outputs, classifiers, Losses, Accs, losses, accuracies): + # PASS THE OUTPUT(S) TO ITS/THEIR CLASSIFIER + output = classifier(output.detach()) + # COMPUTE THE CLASSIFIER'S LOSS AND ACCURACIES + l = Loss(output, target) + acc = Acc(output, target) + loss.update(l.item(), B) + for j in range(3): + accuracy[j].update(acc[j].item(), B) + + + # PRINT PERFORMANCES INDEXES AT EVERY EPOCH + for loss, acc in zip(losses, accuracies): + print('Validation loss: {:.4f} | top1: {:.4f} | top3: {:.4f} | top5: {:.4f}'.format(loss.avg, acc[0].avg, acc[1].avg ,acc[2].avg)) + return [loss.local_avg for loss in losses], [[acc[0].avg,acc[1].avg,acc[2].avg] for acc in accuracies] + + +# USE OF THE SAME LEARNING RATE SCHEDULER AS DPC, SHOULD BE TAKEN AWAY FOR MORE STABLE RESULTS +def MultiStepLR_Restart_Multiplier(epoch, gamma=0.1, step=[10,15,20], repeat=3): + '''return the multipier for LambdaLR, + 0 <= ep < 10: gamma^0 + 10 <= ep < 15: gamma^1 + 15 <= ep < 20: gamma^2 + 20 <= ep < 30: gamma^0 ... 
repeat 3 cycles and then keep gamma^2''' + max_step = max(step) + effective_epoch = epoch % max_step + if epoch // max_step >= repeat: + exp = len(step) - 1 + else: + exp = len([i for i in step if effective_epoch>=i]) + return gamma ** exp + + +if __name__ == '__main__': + args = sys.argv + + classify() \ No newline at end of file diff --git a/video/train.py b/video/train.py new file mode 100644 index 0000000..8d2de8e --- /dev/null +++ b/video/train.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- + +from tqdm import tqdm +import torch +from utils import AverageMeter + + +def train(model, data_loader, optimizer, epoch): + cuda = torch.device('cuda') + # SET THE LOSSES AND ACCURACIES + # WARNING: USING x*[OBJECT] DUPLICATES REFERENCES TO THE SAME OBJECT INSTANCE + # HENCE THE FOR LOOP + losses = [] + accuracies = [] + for i in range( model.get_nb_losses()): + losses.append(AverageMeter()) + accuracies.append([AverageMeter(),AverageMeter(),AverageMeter()]) + + model.train() + + for input_seq in tqdm(data_loader): + # INPUT HAS DIMENSION (B, 1, 3, T, X, Y) + # 1 COMES FROM DPC TREATING SEQUENCES OF SEQUENCES,SOLVED WITH SQUEEZE + input_seq = input_seq.squeeze().to(cuda) + B = input_seq.size(0) + + + # MODEL OUTPUTS LOSSES, ACCURACIES, Z + res_losses, res_accs, _ = model(input_seq) + + + # UPDATE THE RUNNING LOSSES AND ACCURACIES + for i,(res_loss, res_acc) in enumerate(zip(res_losses, res_accs)): + losses[i].update(res_loss.item(), B) + for j in range(3): + accuracies[i][j].update(res_acc[j].item(), B) + + + # PERFORM BACKWARD(S) AND BACK-PROPAGATION + optimizer.zero_grad() + for loss in res_losses: + loss.backward() + optimizer.step() + + + # PRINT PERFORMANCES INDEXES AT EVERY EPOCH + for loss, acc in zip(losses, accuracies): + print('Training loss: {:.4f} | top1: {:.4f} | top3: {:.4f} | top5: {:.4f}'.format(loss.avg, acc[0].avg, acc[1].avg ,acc[2].avg)) + return [loss.local_avg for loss in losses], [[acc[0].avg,acc[1].avg,acc[2].avg] for acc in accuracies] + + + +def validate(model, data_loader, epoch): + cuda = torch.device('cuda') + # SET THE LOSSES AND ACCURACIES + # WARNING: USING x*[OBJECT] DUPLICATES REFERENCES TO THE SAME OBJECT INSTANCE + # HENCE THE FOR LOOP + losses = [] + accuracies = [] + for i in range(model.get_nb_losses()): + losses.append(AverageMeter()) + accuracies.append([AverageMeter(),AverageMeter(),AverageMeter()]) + + model.eval() + + for input_seq in tqdm(data_loader): + # INPUT HAS DIMENSION (B, 1, 3, T, X, Y) + # 1 COMES FROM DPC TREATING SEQUENCES OF SEQUENCES,SOLVED WITH SQUEEZE + input_seq = input_seq.squeeze().to(cuda) + B = input_seq.size(0) + + + # MODEL OUTPUTS LOSSES, ACCURACIES, Z + res_losses, res_accs, _ = model(input_seq) + + + # UPDATE THE RUNNING LOSSES AND ACCURACIES + for i,(res_loss, res_acc) in enumerate(zip(res_losses, res_accs)): + losses[i].update(res_loss.item(), B) + for j in range(3): + accuracies[i][j].update(res_acc[j].item(), B) + + + # PRINT PERFORMANCES INDEXES AT EVERY EPOCH + for loss, acc in zip(losses, accuracies): + print('Validation loss: {:.4f} | top1: {:.4f} | top3: {:.4f} | top5: {:.4f}'.format(loss.avg, acc[0].avg, acc[1].avg ,acc[2].avg)) + return [loss.local_avg for loss in losses], [[acc[0].avg,acc[1].avg,acc[2].avg] for acc in accuracies] \ No newline at end of file diff --git a/video/utils.py b/video/utils.py new file mode 100644 index 0000000..8723fd1 --- /dev/null +++ b/video/utils.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +from collections import deque +import numpy as np + + +class AverageMeter(object): + 
"""Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.local_history = deque([]) + self.local_avg = 0 + self.history = [] + self.dict = {} # save all data values here + self.save_dict = {} # save mean and std here, for summary table + + def update(self, val, n=1, history=0, step=5): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + if history: + self.history.append(val) + if step > 0: + self.local_history.append(val) + if len(self.local_history) > step: + self.local_history.popleft() + self.local_avg = np.average(self.local_history) + + def dict_update(self, val, key): + if key in self.dict.keys(): + self.dict[key].append(val) + else: + self.dict[key] = [val] + + def __len__(self): + return self.count + \ No newline at end of file diff --git a/vision/.gitignore b/vision/.gitignore new file mode 100644 index 0000000..84806fc --- /dev/null +++ b/vision/.gitignore @@ -0,0 +1,29 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ + +.idea/ + +## do not push heavy files which are created during training to github +logs/ +datasets/ \ No newline at end of file diff --git a/vision/GreedyInfoMax/__init__.py b/vision/GreedyInfoMax/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/vision/GreedyInfoMax/utils/__init__.py b/vision/GreedyInfoMax/utils/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/vision/GreedyInfoMax/utils/logger.py b/vision/GreedyInfoMax/utils/logger.py new file mode 100755 index 0000000..b4f40fb --- /dev/null +++ b/vision/GreedyInfoMax/utils/logger.py @@ -0,0 +1,192 @@ +import os +import torch +import matplotlib.pyplot as plt +import numpy as np +import copy + + +class Logger: + def __init__(self, opt): + self.opt = opt + + if opt.validate: + self.val_loss = [[] for i in range(opt.model_splits)] + else: + self.val_loss = None + + self.train_loss = [[] for i in range(opt.model_splits)] + + if opt.start_epoch > 0: + self.loss_last_training = np.load( + os.path.join(opt.model_path, "train_loss.npy") + ).tolist() + self.train_loss[:len(self.loss_last_training)] = copy.deepcopy(self.loss_last_training) + + + if opt.validate: + self.val_loss_last_training = np.load( + os.path.join(opt.model_path, "val_loss.npy") + ).tolist() + self.val_loss[:len(self.val_loss_last_training)] = copy.deepcopy(self.val_loss_last_training) + else: + self.val_loss = None + else: + self.loss_last_training = None + + if opt.validate: + self.val_loss = [[] for i in range(opt.model_splits)] + else: + self.val_loss = None + + self.num_models_to_keep = 1 + assert self.num_models_to_keep > 0, "Dont delete all models!!!" 
+ + def create_log( + self, + model, + accuracy=None, + epoch=0, + optimizer=None, + final_test=False, + final_loss=None, + acc5=None, + classification_model=None + ): + + print("Saving model and log-file to " + self.opt.log_path) + + # Save the model checkpoint + if self.opt.experiment == "vision": + for idx, layer in enumerate(model.module.encoder): + torch.save( + layer.state_dict(), + os.path.join(self.opt.log_path, "model_{}_{}.ckpt".format(idx, epoch)), + ) + else: + torch.save( + model.state_dict(), + os.path.join(self.opt.log_path, "model_{}.ckpt".format(epoch)), + ) + + ### remove old model files to keep dir uncluttered + if (epoch - self.num_models_to_keep) % 10 != 0: + try: + if self.opt.experiment == "vision": + for idx, _ in enumerate(model.module.encoder): + os.remove( + os.path.join( + self.opt.log_path, + "model_{}_{}.ckpt".format(idx, epoch - self.num_models_to_keep), + ) + ) + else: + os.remove( + os.path.join( + self.opt.log_path, + "model_{}.ckpt".format(epoch - self.num_models_to_keep), + ) + ) + except: + print("not enough models there yet, nothing to delete") + + + if classification_model is not None: + # Save the predict model checkpoint + torch.save( + classification_model.state_dict(), + os.path.join(self.opt.log_path, "classification_model_{}.ckpt".format(epoch)), + ) + + ### remove old model files to keep dir uncluttered + try: + os.remove( + os.path.join( + self.opt.log_path, + "classification_model_{}.ckpt".format(epoch - self.num_models_to_keep), + ) + ) + except: + print("not enough models there yet, nothing to delete") + + if optimizer is not None: + for idx, optims in enumerate(optimizer): + torch.save( + optims.state_dict(), + os.path.join( + self.opt.log_path, "optim_{}_{}.ckpt".format(idx, epoch) + ), + ) + + try: + os.remove( + os.path.join( + self.opt.log_path, + "optim_{}_{}.ckpt".format( + idx, epoch - self.num_models_to_keep + ), + ) + ) + except: + print("not enough models there yet, nothing to delete") + + # Save hyper-parameters + with open(os.path.join(self.opt.log_path, "log.txt"), "w+") as cur_file: + cur_file.write(str(self.opt)) + if accuracy is not None: + cur_file.write("Top 1 - accuracy: " + str(accuracy)) + if acc5 is not None: + cur_file.write("Top 5 - Accuracy: " + str(acc5)) + if final_test and accuracy is not None: + cur_file.write(" Very Final testing accuracy: " + str(accuracy)) + if final_test and acc5 is not None: + cur_file.write(" Very Final testing top 5 - accuracy: " + str(acc5)) + + # Save losses throughout training and plot + np.save( + os.path.join(self.opt.log_path, "train_loss"), np.array(self.train_loss) + ) + + if self.val_loss is not None: + np.save( + os.path.join(self.opt.log_path, "val_loss"), np.array(self.val_loss) + ) + + self.draw_loss_curve() + + if accuracy is not None: + np.save(os.path.join(self.opt.log_path, "accuracy"), accuracy) + + if final_test: + np.save(os.path.join(self.opt.log_path, "final_accuracy"), accuracy) + np.save(os.path.join(self.opt.log_path, "final_loss"), final_loss) + + + def draw_loss_curve(self): + for idx, loss in enumerate(self.train_loss): + lst_iter = np.arange(len(loss)) + plt.plot(lst_iter, np.array(loss), "-b", label="train loss") + + if self.loss_last_training is not None and len(self.loss_last_training) > idx: + lst_iter = np.arange(len(self.loss_last_training[idx])) + plt.plot(lst_iter, self.loss_last_training[idx], "-g") + + if self.val_loss is not None and len(self.val_loss) > idx: + lst_iter = np.arange(len(self.val_loss[idx])) + plt.plot(lst_iter, 
np.array(self.val_loss[idx]), "-r", label="val loss") + + plt.xlabel("epoch") + plt.ylabel("loss") + plt.legend(loc="upper right") + # plt.axis([0, max(200,len(loss)+self.opt.start_epoch), 0, -round(np.log(1/(self.opt.negative_samples+1)),1)]) + + # save image + plt.savefig(os.path.join(self.opt.log_path, "loss_{}.png".format(idx))) + plt.close() + + def append_train_loss(self, train_loss): + for idx, elem in enumerate(train_loss): + self.train_loss[idx].append(elem) + + def append_val_loss(self, val_loss): + for idx, elem in enumerate(val_loss): + self.val_loss[idx].append(elem) diff --git a/vision/GreedyInfoMax/utils/model_utils.py b/vision/GreedyInfoMax/utils/model_utils.py new file mode 100755 index 0000000..d186d7f --- /dev/null +++ b/vision/GreedyInfoMax/utils/model_utils.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import os + + +def distribute_over_GPUs(opt, model, num_GPU): + ## distribute over GPUs + if opt.device.type != "cpu": + if num_GPU is None: + model = nn.DataParallel(model) + num_GPU = torch.cuda.device_count() + opt.batch_size_multiGPU = opt.batch_size * num_GPU + else: + assert ( + num_GPU <= torch.cuda.device_count() + ), "You cant use more GPUs than you have." + model = nn.DataParallel(model, device_ids=list(range(num_GPU))) + opt.batch_size_multiGPU = opt.batch_size * num_GPU + else: + model = nn.DataParallel(model) + opt.batch_size_multiGPU = opt.batch_size + + model = model.to(opt.device) + print("Let's use", num_GPU, "GPUs!") + + return model, num_GPU + + +def genOrthgonal(dim): + a = torch.zeros((dim, dim)).normal_(0, 1) + q, r = torch.qr(a) + d = torch.diag(r, 0).sign() + diag_size = d.size(0) + d_exp = d.view(1, diag_size).expand(diag_size, diag_size) + q.mul_(d_exp) + return q + + +def makeDeltaOrthogonal(weights, gain): + rows = weights.size(0) + cols = weights.size(1) + if rows > cols: + print("In_filters should not be greater than out_filters.") + weights.data.fill_(0) + dim = max(rows, cols) + q = genOrthgonal(dim) + mid1 = weights.size(2) // 2 + mid2 = weights.size(3) // 2 + with torch.no_grad(): + weights[:, :, mid1, mid2] = q[: weights.size(0), : weights.size(1)] + weights.mul_(gain) + + +def reload_weights(opt, model, optimizer, reload_model): + ## reload weights for training of the linear classifier + if (opt.model_type == 0) and reload_model: + print("Loading weights from ", opt.model_path) + + if opt.experiment == "audio": + model.load_state_dict( + torch.load( + os.path.join(opt.model_path, "model_{}.ckpt".format(opt.model_num)), + map_location=opt.device.type, + ) + ) + else: + for idx, layer in enumerate(model.module.encoder): + model.module.encoder[idx].load_state_dict( + torch.load( + os.path.join( + opt.model_path, + "model_{}_{}.ckpt".format(idx, opt.model_num), + ), + map_location=opt.device.type, + ) + ) + + ## reload weights and optimizers for continuing training + elif opt.start_epoch > 0: + print("Continuing training from epoch ", opt.start_epoch) + + if opt.experiment == "audio": + model.load_state_dict( + torch.load( + os.path.join( + opt.model_path, "model_{}.ckpt".format(opt.start_epoch) + ), + map_location=opt.device.type, + ), + strict=False, + ) + else: + for idx, layer in enumerate(model.module.encoder): + model.module.encoder[idx].load_state_dict( + torch.load( + os.path.join( + opt.model_path, + "model_{}_{}.ckpt".format(idx, opt.start_epoch), + ), + map_location=opt.device.type, + ) + ) + + for i, optim in enumerate(optimizer): + if opt.model_splits > 3 and i > 2: + break + optim.load_state_dict( + torch.load( + 
os.path.join( + opt.model_path, + "optim_{}_{}.ckpt".format(str(i), opt.start_epoch), + ), + map_location=opt.device.type, + ) + ) + else: + print("Randomly initialized model") + + return model, optimizer \ No newline at end of file diff --git a/vision/GreedyInfoMax/utils/utils.py b/vision/GreedyInfoMax/utils/utils.py new file mode 100755 index 0000000..3bbebea --- /dev/null +++ b/vision/GreedyInfoMax/utils/utils.py @@ -0,0 +1,69 @@ +import os +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from sklearn.manifold import TSNE +import torch + + +def get_device(opt, input_tensor): + if opt.device.type != "cpu": + cur_device = input_tensor.get_device() + else: + cur_device = opt.device + + return cur_device + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size).item()) + return res + + +def scatter(opt, x, colors, label): + """ + creates scatter plot for t-SNE visualization + :param x: 2-D latent space as output by t-SNE + :param colors: labels for each datapoint in x, used to assign different colors to them + :param idx: used for naming the file, to be able to track progress throughout training + """ + # We choose a color palette with seaborn. + palette = np.array(sns.color_palette("hls", 10)) + + # We create a scatter plot. + plt.figure(figsize=(8, 8)) + ax = plt.subplot(aspect="equal") + ax.scatter(x[:, 0], x[:, 1], lw=0, s=40, c=palette[colors.ravel().astype(np.int)]) + plt.xlim(-25, 25) + plt.ylim(-25, 25) + ax.axis("off") + ax.axis("tight") + + plt.savefig( + os.path.join(opt.log_path_latent, "latent_space_{}.png".format(label)), dpi=120 + ) + plt.close() + + +def fit_TSNE_and_plot(opt, feature_space, speaker_labels, label): + feature_space = np.reshape( + feature_space, (np.shape(feature_space)[0] * np.shape(feature_space)[1], -1) + ) + speaker_labels = np.reshape(speaker_labels, (-1, 1)) + + # X: array, shape(n_samples, n_features) + projection = TSNE().fit_transform(feature_space) + + scatter(opt, projection, speaker_labels, label) diff --git a/vision/GreedyInfoMax/vision/__init__.py b/vision/GreedyInfoMax/vision/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/vision/GreedyInfoMax/vision/arg_parser/__init__.py b/vision/GreedyInfoMax/vision/arg_parser/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/vision/GreedyInfoMax/vision/arg_parser/arg_parser.py b/vision/GreedyInfoMax/vision/arg_parser/arg_parser.py new file mode 100755 index 0000000..a1e86a8 --- /dev/null +++ b/vision/GreedyInfoMax/vision/arg_parser/arg_parser.py @@ -0,0 +1,48 @@ +from optparse import OptionParser +import time +import os +import torch +import numpy as np + +from GreedyInfoMax.vision.arg_parser import reload_args, train_args, general_args + + +def parse_args(): + # load parameters and options + parser = OptionParser() + + parser = general_args.parse_general_args(parser) + parser = train_args.parse_train_args(parser) + parser = reload_args.parser_reload_args(parser) + + (opt, _) = parser.parse_args() + + opt.time = time.ctime() + + # Device configuration + opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + 
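+    # note: opt.time set above doubles as the fallback log-directory name in
+    # create_log_path below when no --save_dir is given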
opt.experiment = "vision" + + return opt + + +def create_log_path(opt, add_path_var=""): + unique_path = False + + if opt.save_dir != "": + opt.log_path = os.path.join(opt.data_output_dir, "logs", opt.save_dir) + unique_path = True + elif add_path_var == "features" or add_path_var == "images": + opt.log_path = os.path.join(opt.data_output_dir, "logs", add_path_var, os.path.basename(opt.model_path)) + unique_path = True + else: + opt.log_path = os.path.join(opt.data_output_dir, "logs", add_path_var, opt.time) + + # hacky way to avoid overwriting results of experiments when they start at exactly the same time + while os.path.exists(opt.log_path) and not unique_path: + opt.log_path += "_" + str(np.random.randint(100)) + + if not os.path.exists(opt.log_path): + os.makedirs(opt.log_path) + diff --git a/vision/GreedyInfoMax/vision/arg_parser/general_args.py b/vision/GreedyInfoMax/vision/arg_parser/general_args.py new file mode 100755 index 0000000..60c7db0 --- /dev/null +++ b/vision/GreedyInfoMax/vision/arg_parser/general_args.py @@ -0,0 +1,73 @@ +def parse_general_args(parser): + parser.add_option( + "--experiment", + type="string", + default="vision", + help="not a real option, just for bookkeeping", + ) + parser.add_option( + "--dataset", + type="string", + default="stl10", + help="Dataset to use for training, default: stl10", # cifar10, cifar100 + ) + parser.add_option( + "--download_dataset", + action="store_true", + default=False, + help="Boolean to decide whether to download the dataset to train on (only tested for STL-10)", + ) + parser.add_option( + "--num_epochs", type="int", default=300, help="Number of Epochs for Training" + ) + parser.add_option("--seed", type="int", default=2, help="Random seed for training") + parser.add_option("--batch_size", type="int", default=32, help="Batchsize") + parser.add_option( + "-i", + "--data_input_dir", + type="string", + default="./datasets", + help="Directory to store bigger datafiles (dataset and models)", + ) + parser.add_option( + "-o", + "--data_output_dir", + type="string", + default=".", + help="Directory to store bigger datafiles (dataset and models)", + ) + parser.add_option( + "--validate", + action="store_true", + default=False, + help="Boolean to decide whether to split train dataset into train/val and plot validation loss (True) or combine train+validation set for final testing (False)", + ) + parser.add_option( + "--loss", + type="int", + default=0, + help="Loss function to use for training:" + "0 - InfoNCE loss" + "1 - supervised loss using class labels", + ) + parser.add_option( + "--grayscale", + action="store_true", + default=True, + help="Boolean to decide whether to convert images to grayscale (default: true)", + ) + parser.add_option( + "--weight_init", + action="store_true", + default=False, + help="Boolean to decide whether to use special weight initialization (delta orthogonal)", + ) + parser.add_option( + "--save_dir", + type="string", + default="", + help="If given, uses this string to create directory to save results in " + "(be careful, this can overwrite previous results); " + "otherwise saves logs according to time-stamp", + ) + return parser diff --git a/vision/GreedyInfoMax/vision/arg_parser/reload_args.py b/vision/GreedyInfoMax/vision/arg_parser/reload_args.py new file mode 100755 index 0000000..d5de470 --- /dev/null +++ b/vision/GreedyInfoMax/vision/arg_parser/reload_args.py @@ -0,0 +1,59 @@ +from optparse import OptionGroup + +def parser_reload_args(parser): + group = OptionGroup(parser, "Reloading pretrained 
model options") + + ### Options to load pretrained models + group.add_option( + "--start_epoch", + type="int", + default=0, + help="Epoch to start GIM training from: " + "v=0 - start training from scratch, " + "v>0 - load pre-trained model that was trained for v epochs and continue training " + "(path to pre-trained model needs to be specified in opt.model_path)", + ) + group.add_option( + "--model_path", + type="string", + default=".", + help="Directory of the saved model (path within --data_input_dir)", + ) + group.add_option( + "--model_num", + type="string", + default="100", + help="Number of the saved model to be used for training the linear classifier" + "(loaded using model_path + model_X.ckpt, where X is the model_num passed here)", + ) + group.add_option( + "--model_type", + type="int", + default=0, + help="Which type of model to use for training of linear classifier on downstream task:" + "0 - pretrained GreedyInfoMax/CPC model" + "1 - randomly initialized model" + "2 - fully supervised model", + ) + group.add_option( + "--module_num", + type="int", + default=3, + help="Module to use for training of linear classifier on downstream task (Using 1-indexing). -1 means direct classification on (flattened) images.", + ) + group.add_option( + "--in_channels", + type=int, + default=None, + help="Option to explicitly specify the number of input channels for the linear classifier." + "If None, the default options for resnet output is taken", + ) + group.add_option( + "--save_vars_for_update_calc", + type=int, + default=-1, + help="Save intermediate activation for manual update calculation at given layer (1-6). CAREFUL: This constantly increases model size! Only apply to one update!", + ) + + parser.add_option_group(group) + return parser diff --git a/vision/GreedyInfoMax/vision/arg_parser/train_args.py b/vision/GreedyInfoMax/vision/arg_parser/train_args.py new file mode 100755 index 0000000..95ab1bf --- /dev/null +++ b/vision/GreedyInfoMax/vision/arg_parser/train_args.py @@ -0,0 +1,213 @@ +from optparse import OptionGroup + +def parse_train_args(parser): + group = OptionGroup(parser, "Training options") + group.add_option( + "--learning_rate", + type="float", + default=2e-4, + help="Learning rate (for ADAM optimiser)" + ) + group.add_option( + "--weight_decay", + type="float", + default=0., + help="weight decay or l2-penalty on weights (for ADAM optimiser, default = 0., i.e. no l2-penalty)" + ) + group.add_option( + "--prediction_step", + type="int", + default=5, + help="(Number of) Time steps to predict into future", + ) + group.add_option( + "--gradual_prediction_steps", + action="store_true", + default=False, + help="Increase number of time steps (to predict into future) module by module. This is meant to be used with 6 modules", + ) + group.add_option( + "--reduced_patch_pooling", + action="store_true", + default=False, + help="Reduce adaptive average pooling of patch encodings. This means that some spatial information is kept." + "The dimension of context and target vectors grow accordingly. 
This is meant to be used with 6 modules.", + ) + group.add_option( + "--negative_samples", + type="int", + default=16, + help="Number of negative samples to be used for training", + ) + group.add_option( + "--current_rep_as_negative", + action="store_true", + default=False, + help="Use the current feature vector ('context' at time t as opposed to predicted time step t+k) itself as/for sampling the negative sample", + ) + group.add_option( + "--sample_negs_locally", + action="store_true", + default=False, + help="Sample neg. samples from batch but within same location in image, i.e. no shuffling across locations", + ) + group.add_option( + "--sample_negs_locally_same_everywhere", + action="store_true", + default=False, + help="Extension of --sample_negs_locally_same_everywhere (must be True). No shuffling across locations and same sample (from batch) for all locations. I.e. negative sample is simply a new input without any scrambling", + ) + group.add_option( + "--either_pos_or_neg_update", + action="store_true", + default=False, + help="Randomly chose to do either pos or neg update in Hinge loss. --negative_samples should be 1. Only used with --current_rep_as_negative True", + ) + group.add_option( + "--patch_size", + type="int", + default=16, + help="Encoding patch size. Use single integer for same encoding size for all modules (default=16)", + ) + group.add_option( + "--increasing_patch_size", + action="store_true", + default=False, + help="Boolean: start with patch size 4 and increase by factors 2 per module until max. patch size = --patch_size (e.g. 16)", + ) + group.add_option( + "--random_crop_size", + type="int", + default=64, + help="Size of the random crop window. Use single integer for same size for all modules (default=64)", + ) + group.add_option( + "--inpatch_prediction", + action="store_true", + default=False, + help="Boolean: change CPC task to smaller scale prediction (within patch -> smaller receptive field) by extra unfolding ", + ) + group.add_option( + "--inpatch_prediction_limit", + type="int", + default=2, + help="Number of module below which inpatch prediction is applied (if inpatch prediction is active) (default=2, i.e. modules 0 and 1 are doing inpatch prediction)", + ) + group.add_option( + "--feedback_gating", + action="store_true", + default=False, + help="Boolean: use feedback from higher layers to gate lower layer plasticity", + ) + group.add_option( + "--gating_av_over_preds", + action="store_true", + default=False, + help="Boolean: average feedback gating (--feedback_gating) from higher layers over different prediction steps ('k')", + ) + group.add_option( + "--contrast_mode", + type="str", + default="multiclass", + help="decides whether constrasting with neg. 
examples is done at once 'mutliclass' " + "or one at a time with (and then averaged) with CE 'binary', BCE 'logistic' or 'hinge' loss", + ) + group.add_option( + "--detach_c", + action="store_true", + default=False, + help="Boolean whether the gradient of the context c should be dropped (detached)", + ) + group.add_option( + "--encoder_type", + type="str", + default="resnet", + help="Select the encoder type: resnet or vgg_like", + ) + group.add_option( + "--inference_recurrence", + type="int", + default=0, + help="recurrence (on the module level) during inference (before evaluating loss):" + "0 - no recurrence" + "1 - lateral recurrence within layer" + "2 - feedback recurrence" + "3 - both, lateral and feedback recurrence", + ) + group.add_option( + "--recurrence_iters", + type="int", + default=5, + help="number of iterations for inference recurrence (without recurrence, --inference_recurrence == 0, it is set to 0) ", + ) + group.add_option( + "--model_splits", + type="int", + default=3, + help="Number of individually trained modules that the original model should be split into " + "options: 1 (normal end-to-end backprop) or 3 (default used in experiments of paper)", + ) + group.add_option( + "--train_module", + type="int", + default=3, + help="Index of the module to be trained individually (0-2), " + "or training network as one (3)", + ) + group.add_option( + "--predict_module_num", + type="str", + default="same", + help="Option whether W should predict activities in the same module ('same', default), " + "one module below with first module predicting same module ('-1')," + "both ('both') or" + "one module below with last module predicting same module ('-1b')", + ) + group.add_option( + "--extra_conv", + action="store_true", + default=False, + help="Boolian whether extra convolutional layer too increase rec. field size (with downsampling, i.e. stride > 1)" + "is used to decode activity before avg-pooling and contrastive loss", + ) + group.add_option( + "--asymmetric_W_pred", + action="store_true", + default=False, + help="Boolean: solve weight transport in W_pred by using two distinct W_pred(1,2) and splitting the score:" + "Loss(u) -> Loss1(u1) + Loss2(u2) for both, pos. and neg. samples, with" + "u = z*W_pred*c -> u1 = drop_grad(z)*W_pred1*c, u2 = z*W_pred2*drop_grad(c)", + ) + group.add_option( + "--freeze_W_pred", + action="store_true", + default=False, + help="Boolean whether the k prediction weights W_pred (W_k in InfoNCE_Loss) are frozen (require_grad=False).", + ) + group.add_option( + "--unfreeze_last_W_pred", + action="store_true", + default=False, + help="Boolean whether the k prediction weights W_pred of the last module should be unfrozen.", + ) + group.add_option( + "--skip_upper_c_update", + action="store_true", + default=False, + help="Boolean whether extra update in upper (context) layer is skipped. Consider this when predicting lower modules", + ) + group.add_option( + "--no_gamma", + action="store_true", + default=False, + help="Boolean whether gamma (factor which sets the opposite sign of the update for pos and neg samples) is set to 1. i.e. third factor omitted in learning rule", + ) + group.add_option( + "--no_pred", + action="store_true", + default=False, + help="Boolean whether Wpred * c is set to 1 (no prediction). i.e. 
fourth factor omitted in learning rule", + ) + + parser.add_option_group(group) + return parser diff --git a/vision/GreedyInfoMax/vision/compare_updates.py b/vision/GreedyInfoMax/vision/compare_updates.py new file mode 100755 index 0000000..34a3ec5 --- /dev/null +++ b/vision/GreedyInfoMax/vision/compare_updates.py @@ -0,0 +1,389 @@ +# Code to compare numerically the updates stemming from (i) CLAPP learning rules and (ii) CLAPP loss + autodiff in pytorch +# ATTENTION: This is tested for +# 1) using the same (single) negative everywhere (local sampling): --sample_negs_locally --sample_negs_locally_same_everywhere +# 2) not using W_retro for the moment (i.e. NOT --asymmetric_W_pred) + +# Respective simulations need to be run/created (running 'GreedyInfoMax.vision.main_vision' with the same command line options. +# However the below tests should hold at any point of training, e.g. also at the first epoch of training + +# Bash command for tested cases: +# python -m GreedyInfoMax.vision.compare_updates --download_dataset --save_dir CLAPP_1 --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --contrast_mode 'hinge' --num_epochs 600 --negative_samples 1 --sample_negs_locally --sample_negs_locally_same_everywhere --start_epoch 598 --model_path ./logs/CLAPP_s/ --save_vars_for_update_calc 3 --batch_size 4 +# python -m GreedyInfoMax.vision.compare_updates --download_dataset --save_dir CLAPP_2 --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --contrast_mode 'hinge' --num_epochs 600 --negative_samples 1 --sample_negs_locally --sample_negs_locally_same_everywhere --either_pos_or_neg_update --start_epoch 599 --model_path ./logs/CLAPP/ --save_vars_for_update_calc 3 --batch_size 4 + +################################################################################ + +from shutil import which +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from numpy.random import choice +import time +import os +import code +import sklearn +from IPython import embed +import matplotlib.pyplot as plt +import copy + +## own modules +from GreedyInfoMax.vision.data import get_dataloader +from GreedyInfoMax.vision.arg_parser import arg_parser +from GreedyInfoMax.vision.models import load_vision_model +from GreedyInfoMax.utils import logger, utils + + +def train_iter(opt, model, train_loader): + model.module.switch_calc_loss(True) + + starttime = time.time() + cur_train_module = opt.train_module + + img, label = next(iter(train_loader)) + + model_input = img.to(opt.device) + label = label.to(opt.device) + loss, loss_gated, _, z, accuracy = model(model_input, label, n=cur_train_module) + loss = torch.mean(loss, 0) # take mean over outputs of different GPUs + + model.zero_grad() + for idx in range(len(loss)): + if idx == len(loss) - 1: # last module + loss[idx].backward() + else: + # select loss or loss_gated for gradient descent + loss_for_grads = loss_gated[idx] if model.module.opt.feedback_gating else loss[idx] + loss_for_grads.backward(retain_graph=True) + + +def _load_activations(opt, layer, k): + which_update = torch.load(os.path.join(opt.model_path, 'saved_which_update_layer_'+str(layer)), map_location=torch.device('cpu')) + + (context, z_p, z_n, rand_index) = torch.load(os.path.join(opt.model_path, 'saved_c_and_z_layer_'+str(layer)+'_k_'+str(k)), map_location=torch.device('cpu')) + context = context.squeeze(-2) + # all size: y (red.), x, b, c + + (loss_p, loss_n) = torch.load(os.path.join(opt.model_path, 'saved_losses_layer_'+str(layer)+'_k_'+str(k)), 
map_location=torch.device('cpu')) + # b, 1, y (red.), x + dloss_p = - torch.sign(loss_p.squeeze(1).permute(1, 2, 0)) + dloss_n = torch.sign(loss_n.squeeze(1).permute(1, 2, 0)) + # y (red.), x, b + + return which_update, context, z_p, z_n, rand_index, dloss_p, dloss_n + +def _add_pos_and_neg_dW(which_update, dW_p, dW_n): + def _select_dWs(dW, inds): + if sum(inds) > 0: # exclude empty sets which lead no NaN in loss + if len(dW.shape) == 5: + dW = dW[:, :, inds, :, :] + elif len(dW.shape) == 7: + dW = dW[:, :, inds, :, :, :, :] + else: + dW[:] = 0. + return dW + + if type(which_update) != str: + inds_p = torch.tensor((which_update == 'pos').tolist()) + inds_n = torch.tensor((which_update == 'neg').tolist()) + dW_p = _select_dWs(dW_p, inds_p) + dW_n = _select_dWs(dW_n, inds_n) + + dW_p_m = dW_p.mean(dim=(0,1,2)) # mean over y (red.), x, b + dW_n_m = dW_n.mean(dim=(0,1,2)) + + dW = dW_p_m + dW_n_m + return dW + +def _get_dWpred(opt, layer, k): + which_update, context, z_p, z_n, _, dloss_p, dloss_n = _load_activations(opt, layer, k) + + # post * pre + dWpred_p = torch.einsum("yxbc, yxbd -> yxbcd", context, z_p) + dWpred_n = torch.einsum("yxbc, yxbd -> yxbcd", context, z_n) + # n_p_y (red.), n_p_x, b, c, c + + # * "gamma" + dWpred_p = dloss_p.unsqueeze(-1).unsqueeze(-1) * dWpred_p + dWpred_n = dloss_n.unsqueeze(-1).unsqueeze(-1) * dWpred_n + + dWpred = _add_pos_and_neg_dW(which_update, dWpred_p, dWpred_n) # c, c + + dWpred /= opt.prediction_step # this factor appears in loss + + return dWpred + +def compare_Wpred(opt, model, k): + layer = opt.save_vars_for_update_calc + + # update acc. to CLAPP rule + dWpred = _get_dWpred(opt, layer, k) + + grad_Wpred = model.module.model[0][layer-1].loss.W_k[k-1].weight.grad.squeeze().clone().detach().to('cpu') + # model.module.model[0][layer-1].loss_mirror.W_k[0].weight.grad + + diff = dWpred - grad_Wpred + d = diff.norm() / (dWpred.norm() + grad_Wpred.norm()) + + return diff, d, grad_Wpred, dWpred + + +def _get_dW_ff(opt, layer, skip_step=1): + layer_model = model.module.model[0][layer-1] + conv = layer_model.model[0] + kernel_size = conv.kernel_size[0] + padding = conv.padding[0] + stride = conv.stride[0] + + n_patches = opt.random_crop_size // (opt.patch_size//2) - 1 + layer_inputs_raw = torch.load(os.path.join(opt.model_path, 'saved_input_layer_'+str(opt.save_vars_for_update_calc))) + # padding as in forward path in VGG_like_Encoder.py + pad = nn.ZeroPad2d(padding) # expects 4-dim input: b', c, y, x (b' = b*n_p_y*n_p_y) + layer_inputs_pad = pad(layer_inputs_raw) + s = layer_inputs_pad.shape + layer_inputs = layer_inputs_pad.reshape(-1, n_patches, n_patches, s[1], s[2], s[3]) # b, n_p_y, n_p_x, c, y, x + + # get full, non-average-pooled output of the network + _, out_full_, _, _ = layer_model.forward(layer_inputs_raw, None, 0, n_patches, n_patches, None) # b', c, n_p_y, n_p_x + out_full = out_full_.reshape(-1, n_patches, n_patches, out_full_.shape[1], s[2]-2*padding, s[3]-2*padding) # b, n_p_y, n_p_x, c, y, x + + dW_ff = torch.zeros(conv.weight.shape) # c_post, c_pre, kernel, kernel + for k in range(1, 6): + which_update, context, z_p, z_n, rand_index, dloss_p, dloss_n = _load_activations(opt, layer, k) + + input_z_p = layer_inputs[:, (k + skip_step) :, :, :, :, :] # b, n_p_y (red.), n_p_x, c_pre, y, x (pre for pos. samples) + input_z_n = input_z_p[rand_index, :, :, :, :, :].clone() # b, n_p_y (red.), n_p_x, c_pre, y, x (pre for neg. 
samples) + input_context = layer_inputs[:, : -(k + skip_step), :, :, :, :] # b, n_p_y (red.), n_p_x, c_pre, y, x (pre for context) + + pre_z_p = input_z_p.unfold(4, kernel_size, stride).unfold(5, kernel_size, stride).permute(0, 1, 2, 4, 5, 3, 6, 7) # b, n_p_y (red.), n_p_x, y, x, c_pre, kernel_size, kernel_size + pre_z_n = input_z_n.unfold(4, kernel_size, stride).unfold(5, kernel_size, stride).permute(0, 1, 2, 4, 5, 3, 6, 7) + pre_context = input_context.unfold(4, kernel_size, stride).unfold(5, kernel_size, stride).permute(0, 1, 2, 4, 5, 3, 6, 7) + + # rho'(a), for ReLU -> sign function + out_z_p = out_full[:, (k + skip_step) :, :, :, :, :] # b, n_p_y (red.), n_p_x, c_post, y, x + out_z_n = out_z_p[rand_index, :, :, :, :, :].clone() + out_c = out_full[:, : -(k + skip_step), :, :, :, :] + + post_z_p = torch.sign(out_z_p).permute(0, 1, 2, 4, 5, 3) # b, n_p_y (red.), n_p_x, y, x, c_post + post_z_n = torch.sign(out_z_n).permute(0, 1, 2, 4, 5, 3) + post_context = torch.sign(out_c).permute(0, 1, 2, 4, 5, 3) + + # post * pre + dW_ff_k_p_pred = torch.einsum("bpqyxc, bpqyxdst -> bpqyxcdst", post_z_p.to('cuda'), pre_z_p.to('cuda')).mean(dim=(3,4)).to('cpu') # b, n_p_y, n_p_x, b, c_post, c_pre, k_s, k_s (already av. over x and y positions) + dW_ff_k_n_pred = torch.einsum("bpqyxc, bpqyxdst -> bpqyxcdst", post_z_n.to('cuda'), pre_z_n.to('cuda')).mean(dim=(3,4)).to('cpu') + dW_ff_k_retro = torch.einsum("bpqyxc, bpqyxdst -> bpqyxcdst", post_context.to('cuda'), pre_context.to('cuda')).mean(dim=(3,4)).to('cpu') + + # * "dendrite" (using transposed Wpred as Wretro!!) + # In InfoNCE_Loss.py W_k is implemented with z as input -> W_k = W_retro! + W_retro = copy.deepcopy(model.module.model[0][layer-1].loss.W_k[k-1].to('cpu')) # k - 1 because of zero-indexing! + W_pred = copy.deepcopy(W_retro) + W_pred.weight.data = W_retro.weight.permute(1, 0, 2, 3).clone().detach() + + pred = W_pred.forward(context.permute(2,3,0,1)).permute(0,2,3,1) # prediction (same for pos and neg): b, n_p_y (red.), n_p_x, c_post + retro_p = W_retro.forward(z_p.permute(2,3,0,1)).permute(0,2,3,1) # retrodiction for pos. sample + retro_n = W_retro.forward(z_n.permute(2,3,0,1)).permute(0,2,3,1) # retrodiction for neg. 
sample + + dW_ff_k_p_pred = pred.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_k_p_pred + dW_ff_k_p_retro = retro_p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_k_retro + dW_ff_k_n_pred = pred.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_k_n_pred + dW_ff_k_n_retro = retro_n.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_k_retro + + # add contribution of pred and retro + dW_ff_k_p = (dW_ff_k_p_pred + dW_ff_k_p_retro).permute(1,2,0,3,4,5,6) # n_p_y (red.), n_p_x, b, c_post, c_pre, kernel_size, kernel_size + dW_ff_k_n = (dW_ff_k_n_pred + dW_ff_k_n_retro).permute(1,2,0,3,4,5,6) + + # --detach_c case + # dW_ff_k_p = dW_ff_k_p_pred.permute(1,2,0,3,4,5,6) # n_p_y (red.), n_p_x, b, c_post, c_pre, kernel_size, kernel_size + # dW_ff_k_n = dW_ff_k_n_pred.permute(1,2,0,3,4,5,6) + + # * "gamma" + dW_ff_k_p = dloss_p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_k_p # n_p_y (red.), n_p_x, b, c_post, c_pre, kernel_size, kernel_size + dW_ff_k_n = dloss_n.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_k_n + + dW_ff_k = _add_pos_and_neg_dW(which_update, dW_ff_k_p, dW_ff_k_n) + dW_ff += dW_ff_k.clone() + + dW_ff /= opt.prediction_step + + return dW_ff + +def compare_W_ff(opt, model): + layer = opt.save_vars_for_update_calc + + dW_ff = _get_dW_ff(opt, layer) + + grad_W_ff = model.module.model[0][layer-1].model[0].weight.grad.squeeze().clone().detach().to('cpu') + # model.module.model[0][layer-1].model[0].bias.grad + + diff_ff = dW_ff - grad_W_ff + d_ff = diff_ff.norm() / (dW_ff.norm() + grad_W_ff.norm()) + + return diff_ff, d_ff, grad_W_ff, dW_ff + + + +############################################################################################################## +# Toy examples + +def _get_loss(Wpred, z, c): + Wpred_z = Wpred.forward(z).permute(2,3,0,1) # y, x, b, c + scores = torch.matmul(c.permute(2,3,0,1).unsqueeze(-2), Wpred_z.unsqueeze(-1)).squeeze() # y, x, b + ones = 0.1 * torch.ones(size=scores.shape, dtype=torch.float32) + zeros = torch.zeros(size=scores.shape, dtype=torch.float32) + loss = torch.where(scores < ones, ones - scores, zeros) # y, x, b + + return loss + +def _compare_Wpred_toy(c, z, loss, Wpred): + dWpred_ = torch.einsum("yxbc, yxbd -> yxbcd", c.permute(2,3,0,1), z.permute(2,3,0,1)) # y, x, b, c, c + dloss = - torch.sign(loss) # y, x, b + dWpred_ = dloss.unsqueeze(-1).unsqueeze(-1) * dWpred_ + dWpred = dWpred_.mean(dim=(0,1,2)) + + # get grad + Wpred_grad = Wpred.weight.grad.squeeze().clone() + + diff = dWpred - Wpred_grad + d = diff.norm() / (dWpred.norm() + Wpred_grad.norm()) + + return diff, d, Wpred_grad, dWpred + +def _compare_W_ff_toy(c, z, c_full, z_full, loss, Wpred, layer, in_z_flat, in_c_flat, s): + dloss = - torch.sign(loss) # y, x, b + + # zero pad + unfold + pad = nn.ZeroPad2d(1) # expects 4-dim input: b', c, y, x + in_z_pad = pad(in_z_flat).reshape(s[0], s[1], s[2], -1, s[4]+2, s[5]+2) # b, n_p_y, n_p_x, c, y, x + in_c_pad = pad(in_c_flat).reshape(s[0], s[1], s[2], -1, s[4]+2, s[5]+2) # +2 because of padding + + pre_z = in_z_pad.unfold(4, 3, 1).unfold(5, 3, 1).permute(0, 1, 2, 4, 5, 3, 6, 7) # b, n_p_y, n_p_x, y, x, c_pre, kernel_size, kernel_size + pre_c = in_c_pad.unfold(4, 3, 1).unfold(5, 3, 1).permute(0, 1, 2, 4, 5, 3, 6, 7) + + # sign post # z_full: b, n_p_y, n_p_x, c_post, y, x + post_z = torch.sign(z_full).permute(0, 1, 2, 4, 5, 3) # b, n_p_y, n_p_x, y, x, c_post + post_c = torch.sign(c_full).permute(0, 1, 2, 4, 5, 3) + + # pre * post + dW_ff_pred_ = torch.einsum("bpqyxc, bpqyxdst -> bpqyxcdst", post_z, 
pre_z).permute(1,2,0,3,4,5,6,7,8) # n_p_y, n_p_x, b, y, x, c_post, c_pre, k_s, k_s + dW_ff_retro_ = torch.einsum("bpqyxc, bpqyxdst -> bpqyxcdst", post_c, pre_c).permute(1,2,0,3,4,5,6,7,8) + dW_ff_pred = dW_ff_pred_.mean(dim=(3,4)) # n_p_y, n_p_x, b, c_post, c_pre, k_s, k_s (mean over x and y positions) + dW_ff_retro = dW_ff_retro_.mean(dim=(3,4)) + + # dendrite + W_retro = copy.deepcopy(Wpred) # k - 1 because of zero-indexing! + W_pred = copy.deepcopy(W_retro) + W_pred.weight.data = W_retro.weight.permute(1, 0, 2, 3).clone().detach() + + pred = W_pred.forward(c).permute(2,3,0,1) # prediction: n_p_y, n_p_x, b, c_post + retro = W_retro.forward(z).permute(2,3,0,1) # retrodiction + + dW_ff_pred = pred.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_pred + dW_ff_retro = retro.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff_retro + + # add pred and retro + dW_ff = dW_ff_pred + dW_ff_retro # n_p_y, n_p_x, b, c, c, k, k + + # loss gating + dW_ff = dloss.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * dW_ff + + # mean over y, x, b + dW_ff = dW_ff.mean(dim=(0,1,2)) # c, c, k, k + + # get grad + W_ff_grad = layer[0].weight.grad.squeeze().clone() + + diff_ff = dW_ff - W_ff_grad + d_ff = diff_ff.norm() / (dW_ff.norm() + W_ff_grad.norm()) + + return diff_ff, d_ff, W_ff_grad, dW_ff + +def toyex(): + # define network and inputs + W_ff = nn.Conv2d(3, 5, kernel_size=3, padding=1) + nonlin = nn.ReLU() + layer = nn.Sequential(*[W_ff, nonlin]) + Wpred = nn.Conv2d(5, 5, 1, bias=False) + + x_z = F.relu(torch.randn(10,3,8,12)) # b, c_in, y, x + x_c = F.relu(torch.rand(x_z.shape)) + # extract patches + in_z = x_z.unfold(2, 4, 2).unfold(3, 4, 2).permute(0, 2, 3, 1, 4, 5) # b, n_p_y, n_p_x, c_in, y, x + in_c = x_c.unfold(2, 4, 2).unfold(3, 4, 2).permute(0, 2, 3, 1, 4, 5) + s = in_z.shape + + in_z_flat = in_z.reshape(s[0]*s[1]*s[2], s[3], s[4], s[5]) # b * n_p_y, n_p_x, c_in, y, x + in_c_flat = in_c.reshape(s[0]*s[1]*s[2], s[3], s[4], s[5]) + + # forward path + out_z = layer.forward(in_z_flat) # b', c_out, y, x + out_c = layer.forward(in_c_flat) + + z = F.adaptive_avg_pool2d(out_z, 1).squeeze().reshape(s[0], s[1], s[2], -1).permute(0,3,1,2) # b, c_out, n_p_y, n_p_x + c = F.adaptive_avg_pool2d(out_c, 1).squeeze().reshape(s[0], s[1], s[2], -1).permute(0,3,1,2) + + z_full = out_z.reshape(s[0], s[1], s[2], -1, s[4], s[4]) # b, n_p_y, n_p_x, c, y, x + c_full = out_c.reshape(s[0], s[1], s[2], -1, s[4], s[4]) # b, n_p_y, n_p_x, c, y, x + + loss = _get_loss(Wpred, z, c) + l = loss.mean() + + # backward (calculate gradients) + l.backward() + + ## manual update Wpred + diff, d, Wpred_grad, dWpred = _compare_Wpred_toy(c, z, loss, Wpred) + print("d for Wpred: ", d) + + ## manual update W_ff + diff_ff, d_ff, W_ff_grad, dW_ff = _compare_W_ff_toy(c, z, c_full, z_full, loss, Wpred, layer, in_z_flat, in_c_flat, s) + print("d for W_ff: ", d_ff) + + embed() + + +############################################################################################################## + +if __name__ == "__main__": + + #toyex() + + opt = arg_parser.parse_args() + + if opt.model_splits != 6: + raise Exception("Only works for layer-wise CLAPP, i.e. model_splits = 6 for 6 layers!") + + # check for layers 0 (1) and 2 (3) since the others have maxpooling layers! + if opt.save_vars_for_update_calc != 1 and opt.save_vars_for_update_calc != 3: + raise Exception("Comparison between updates and gradients only implemented for layers without trailing MaxPool layers (i.e. 
layers 1 & 3)") + + opt.training_dataset = "unlabeled" + + if opt.device.type != "cpu": + torch.backends.cudnn.benchmark = True + + # load model + model, optimizer = load_vision_model.load_model_and_optimizer(opt) + + if opt.batch_size != opt.batch_size_multiGPU: + raise Exception("Manual update comparison only supported for 1 GPU. Please use only 1 GPU") + + train_loader, _, supervised_loader, _, test_loader, _ = get_dataloader.get_dataloader( + opt + ) + + # perform one training iter and save reps etc. + train_iter(opt, model, train_loader) + + # check equivalence between grads and updates for Wpred + print("checking equivalence between grads and updates for Wpred...") + for k in range(1,6): + diff, d, grad_Wpred, dWpred = compare_Wpred(opt, model, k) + print("rel. difference between grad and update for Wpred_k for k=", k, ": ", d) + + # check equivalence between grads and updates for W_ff + print("checking equivalence between grads and updates for W_ff...") + diff_ff, d_ff, grad_W_ff, dW_ff = compare_W_ff(opt, model) + print("rel. difference between grad and update for W_ff: ", d_ff) + + embed() + + \ No newline at end of file diff --git a/vision/GreedyInfoMax/vision/data/__init__.py b/vision/GreedyInfoMax/vision/data/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/vision/GreedyInfoMax/vision/data/get_dataloader.py b/vision/GreedyInfoMax/vision/data/get_dataloader.py new file mode 100755 index 0000000..cc32d30 --- /dev/null +++ b/vision/GreedyInfoMax/vision/data/get_dataloader.py @@ -0,0 +1,276 @@ +import torch +import torchvision.transforms as transforms +import torchvision +import os +import numpy as np +from torchvision.transforms import transforms +import torchvision.transforms.functional as TF +from IPython import embed + + +def get_dataloader(opt): + if opt.dataset == "stl10": + print("load STL-10 dataset...") + train_loader, train_dataset, supervised_loader, supervised_dataset, test_loader, test_dataset = get_stl10_dataloader( + opt + ) + elif opt.dataset == "cifar10" or opt.dataset == "cifar100": + train_loader, train_dataset, supervised_loader, supervised_dataset, test_loader, test_dataset = get_cifar_dataloader( + opt + ) + # train_loader and train_dataset are None in this case! 
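+        # (callers that need labels therefore unpack the supervised loader, i.e. the third return value)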
+ else: + raise Exception("Invalid option") + + # embed() + # raise Exception() + return ( + train_loader, + train_dataset, + supervised_loader, + supervised_dataset, + test_loader, + test_dataset, + ) + + +def get_stl10_dataloader(opt): + base_folder = os.path.join(opt.data_input_dir, "stl10_binary") + + aug = { + "stl10": { + "randcrop": opt.random_crop_size, + "flip": True, + "resize": False, + "pad": False, + "grayscale": opt.grayscale, + "mean": [0.4313, 0.4156, 0.3663], # values for train+unsupervised combined + "std": [0.2683, 0.2610, 0.2687], + "bw_mean": [0.4120], # values for train+unsupervised combined + "bw_std": [0.2570], + } # values for labeled train set: mean [0.4469, 0.4400, 0.4069], std [0.2603, 0.2566, 0.2713] + } + transform_train = transforms.Compose( + [get_transforms(eval=False, aug=aug["stl10"])] + ) + transform_valid = transforms.Compose( + [get_transforms(eval=True, aug=aug["stl10"])] + ) + + unsupervised_dataset = torchvision.datasets.STL10( + base_folder, + split="unlabeled", + transform=transform_train, + download=opt.download_dataset, + ) #set download to True to get the dataset + + train_dataset = torchvision.datasets.STL10( + base_folder, split="train", transform=transform_train, download=opt.download_dataset + ) + + test_dataset = torchvision.datasets.STL10( + base_folder, split="test", transform=transform_valid, download=opt.download_dataset + ) + + # default dataset loaders + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=opt.batch_size_multiGPU, shuffle=True, num_workers=16 + ) + + unsupervised_loader = torch.utils.data.DataLoader( + unsupervised_dataset, + batch_size=opt.batch_size_multiGPU, + shuffle=True, + num_workers=16, + ) + + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=opt.batch_size_multiGPU, shuffle=False, num_workers=16 + ) + + # create train/val split + if opt.validate: + print("Use train / val split") + + if opt.training_dataset == "train": + dataset_size = len(train_dataset) + train_sampler, valid_sampler = create_validation_sampler(dataset_size) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=opt.batch_size_multiGPU, + sampler=train_sampler, + num_workers=16, + ) + + elif opt.training_dataset == "unlabeled": + dataset_size = len(unsupervised_dataset) + train_sampler, valid_sampler = create_validation_sampler(dataset_size) + + unsupervised_loader = torch.utils.data.DataLoader( + unsupervised_dataset, + batch_size=opt.batch_size_multiGPU, + sampler=train_sampler, + num_workers=16, + ) + + else: + raise Exception("Invalid option") + + # overwrite test_dataset and _loader with validation set + test_dataset = torchvision.datasets.STL10( + base_folder, + split=opt.training_dataset, + transform=transform_valid, + download=opt.download_dataset, + ) + + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=opt.batch_size_multiGPU, + sampler=valid_sampler, + num_workers=16, + ) + + else: + print("Use (train+val) / test split") + + return ( + unsupervised_loader, + unsupervised_dataset, + train_loader, + train_dataset, + test_loader, + test_dataset, + ) + + +def create_validation_sampler(dataset_size): + # Creating data indices for training and validation splits: + validation_split = 0.2 + shuffle_dataset = True + + indices = list(range(dataset_size)) + split = int(np.floor(validation_split * dataset_size)) + if shuffle_dataset: + np.random.shuffle(indices) + train_indices, val_indices = indices[split:], indices[:split] + + # Creating data samplers and 
loaders: + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices) + + return train_sampler, valid_sampler + +# only exists in v0.9.9 +# class Sharpen: +# """Sharpen image after upsampling with interpolation.""" +# def __call__(self, x): +# return TF.adjust_sharpness(x, 2.) + +def get_transforms(eval=False, aug=None): + trans = [] + + if aug["resize"]: + trans.append(transforms.Resize(aug["resize_size"])) + + if aug["pad"]: + trans.append(transforms.Pad(aug["pad_size"], fill=0, padding_mode='constant')) + + if aug["randcrop"] and not eval: + trans.append(transforms.RandomCrop(aug["randcrop"])) + + if aug["randcrop"] and eval: + trans.append(transforms.CenterCrop(aug["randcrop"])) + + if aug["flip"] and not eval: + trans.append(transforms.RandomHorizontalFlip()) + + if aug["grayscale"]: + trans.append(transforms.Grayscale()) + trans.append(transforms.ToTensor()) + trans.append(transforms.Normalize(mean=aug["bw_mean"], std=aug["bw_std"])) + elif aug["mean"]: + trans.append(transforms.ToTensor()) + trans.append(transforms.Normalize(mean=aug["mean"], std=aug["std"])) + else: + trans.append(transforms.ToTensor()) + + trans = transforms.Compose(trans) + return trans + + +def get_cifar_dataloader(opt): + cor_factor_mean = 0.06912513 # correction factors: STL-10 normalisation lead to these residual mean and std -> has to be adapted to get same input distribution + cor_factor_std = 0.95930314 + if opt.dataset == "cifar10": + print("load cifar10 dataset...") + base_folder = os.path.join(opt.data_input_dir, "cifar10_binary") + bw_mean = 0.47896898 - cor_factor_mean * 0.2392343 / cor_factor_std + bw_std = 0.2392343 / cor_factor_std + elif opt.dataset == "cifar100": + print("load cifar100 dataset...") + base_folder = os.path.join(opt.data_input_dir, "cifar100_binary") + bw_mean = 0.48563015 - cor_factor_mean * 0.25072286 / cor_factor_std + bw_std = 0.25072286 / cor_factor_std + + aug = { + "cifar": { + "resize": False, + "resize_size": 64, # 96 + "pad": False, + "pad_size": 16, + "randcrop": False, #opt.random_crop_size, + "flip": False, + "grayscale": opt.grayscale, + "bw_mean": [bw_mean], + "bw_std": [bw_std], + } + } + # mean and std found as: + # x = np.concatenate([np.asarray(im) for (im, t) in supervised_loader]); np.mean(x); np.std(x) + # CIFAR10 + # for vanilla 32 x 32 input: "bw_mean": [0.47896898], "bw_std": [0.2392343] + # for resize_size: 96 and randcrop: "bw_mean": [0.470379], "bw_std": [0.2249] + # for resize_size: 64 without randcrop: "bw_mean": [0.4798809], "bw_std": [0.23278822] + # for pad: True and pad_size: 16: "bw_mean": [0.11974239], "bw_std": [0.23942184] + # CIFAR100 + # for vanilla 32 x 32 input: "bw_mean": [0.48563015], "bw_std": [0.25072286] + + transform_train = transforms.Compose( + [get_transforms(eval=False, aug=aug["cifar"])] + ) + transform_valid = transforms.Compose( + [get_transforms(eval=True, aug=aug["cifar"])] + ) + + if opt.dataset == "cifar10": + train_dataset = torchvision.datasets.CIFAR10( + base_folder, train=True, transform=transform_train, download=opt.download_dataset + ) + test_dataset = torchvision.datasets.CIFAR10( + base_folder, train=False, transform=transform_valid, download=opt.download_dataset + ) + elif opt.dataset == "cifar100": + train_dataset = torchvision.datasets.CIFAR100( + base_folder, train=True, transform=transform_train, download=opt.download_dataset + ) + test_dataset = torchvision.datasets.CIFAR100( + base_folder, train=False, 
transform=transform_valid, download=opt.download_dataset + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=opt.batch_size_multiGPU, shuffle=False, num_workers=16 + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=opt.batch_size_multiGPU, shuffle=False, num_workers=16 + ) + + return ( + None, + None, + train_loader, + train_dataset, + test_loader, + test_dataset, + ) \ No newline at end of file diff --git a/vision/GreedyInfoMax/vision/downstream_classification.py b/vision/GreedyInfoMax/vision/downstream_classification.py new file mode 100755 index 0000000..b9f5b7a --- /dev/null +++ b/vision/GreedyInfoMax/vision/downstream_classification.py @@ -0,0 +1,222 @@ +import torch +import numpy as np +import time +import os +import code + +from GreedyInfoMax.vision.data import get_dataloader +from GreedyInfoMax.vision.arg_parser import arg_parser +from GreedyInfoMax.vision.models import load_vision_model +from GreedyInfoMax.utils import logger, utils + + +def train_logistic_regression(opt, context_model, classification_model, train_loader): + total_step = len(train_loader) + classification_model.train() + + starttime = time.time() + + # No randomness in cifar training -> save reps and go without forward through context model! + if opt.dataset == "cifar10" or opt.dataset == "cifar100": + reps = [] + + for epoch in range(opt.num_epochs): + epoch_acc1 = 0 + epoch_acc5 = 0 + + loss_epoch = 0 + + for step, (img, target) in enumerate(train_loader): + + classification_model.zero_grad() + + if (opt.dataset == "cifar10" or opt.dataset == "cifar100") and (epoch != 0): + z = reps[step] + else: + model_input = img.to(opt.device) + + if opt.model_type == 2: # fully supervised training + _, _, z = context_model(model_input) + else: + with torch.no_grad(): + _, _, _, z, _ = context_model(model_input, target, n=opt.module_num) + z = z.detach() # double security that no gradients go to representation learning part of model + + if opt.dataset == "cifar10" or opt.dataset == "cifar100": + reps.append(z) + + prediction = classification_model(z) + + target = target.to(opt.device) + loss = criterion(prediction, target) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # calculate accuracy + acc1, acc5 = utils.accuracy(prediction.data, target, topk=(1, 5)) + epoch_acc1 += acc1 + epoch_acc5 += acc5 + + sample_loss = loss.item() + loss_epoch += sample_loss + + if step % 10 == 0: + print( + "Epoch [{}/{}], Step [{}/{}], Time (s): {:.1f}, Acc1: {:.4f}, Acc5: {:.4f}, Loss: {:.4f}".format( + epoch + 1, + opt.num_epochs, + step, + total_step, + time.time() - starttime, + acc1, + acc5, + sample_loss, + ) + ) + starttime = time.time() + + if opt.validate: + # validate the model - in this case, test_loader loads validation data + val_acc1, _ , val_loss = test_logistic_regression( + opt, context_model, classification_model, test_loader + ) + logs.append_val_loss([val_loss]) + + print("Overall accuracy for this epoch: ", epoch_acc1 / total_step) + logs.append_train_loss([loss_epoch / total_step]) + logs.create_log( + context_model, + epoch=epoch, + classification_model=classification_model, + accuracy=epoch_acc1 / total_step, + acc5=epoch_acc5 / total_step, + ) + + +def test_logistic_regression(opt, context_model, classification_model, test_loader): + total_step = len(test_loader) + context_model.eval() + classification_model.eval() + + starttime = time.time() + + loss_epoch = 0 + epoch_acc1 = 0 + epoch_acc5 = 0 + + for step, (img, target) 
in enumerate(test_loader): + + model_input = img.to(opt.device) + + with torch.no_grad(): + _, _, _, z, _ = context_model(model_input, target, n=opt.module_num) + + z = z.detach() + + prediction = classification_model(z) + + target = target.to(opt.device) + loss = criterion(prediction, target) + + # calculate accuracy + acc1, acc5 = utils.accuracy(prediction.data, target, topk=(1, 5)) + epoch_acc1 += acc1 + epoch_acc5 += acc5 + + sample_loss = loss.item() + loss_epoch += sample_loss + + if step % 10 == 0: + print( + "Step [{}/{}], Time (s): {:.1f}, Acc1: {:.4f}, Acc5: {:.4f}, Loss: {:.4f}".format( + step, total_step, time.time() - starttime, acc1, acc5, sample_loss + ) + ) + starttime = time.time() + + print("Testing Accuracy: ", epoch_acc1 / total_step) + return epoch_acc1 / total_step, epoch_acc5 / total_step, loss_epoch / total_step + + +if __name__ == "__main__": + + opt = arg_parser.parse_args() + + add_path_var = "linear_model" + + arg_parser.create_log_path(opt, add_path_var=add_path_var) + opt.training_dataset = "train" + + # random seeds + torch.manual_seed(opt.seed) + torch.cuda.manual_seed(opt.seed) + np.random.seed(opt.seed) + + # load pretrained model + # cannot switch opt.reduced_patch_pooling = False here because otherwise W_preds sizes don't match + context_model, _ = load_vision_model.load_model_and_optimizer( + opt, reload_model=True, calc_loss=False + ) + context_model.module.switch_calc_loss(False) + if opt.reduced_patch_pooling: + print("ATTENTION: the option --reduced_patch_pooling is only active for GIM/CLAPP training: " + "For downstream classification (from lower layers), activations are still pooled over the whole patch (in the downstream classifier). " + "Otherwise change line below this warning and change the number of in_channels accordingly.") + for module in context_model.module.encoder: + module.patch_average_pool_out_dim = 1 + + + # model_type=2 is supervised model which trains entire architecture; otherwise just extract features + if opt.model_type != 2: + context_model.eval() + + if opt.module_num==-1: + print("CAREFUL! Training classifier directly on input image! 
Model is ignored and returns the (flattened) input images!") + + _, _, train_loader, _, test_loader, _ = get_dataloader.get_dataloader(opt) + + classification_model = load_vision_model.load_classification_model(opt) + + if opt.model_type == 2: + params = list(context_model.parameters()) + list(classification_model.parameters()) + else: + params = classification_model.parameters() + + optimizer = torch.optim.Adam(params) + criterion = torch.nn.CrossEntropyLoss() + + logs = logger.Logger(opt) + + try: + # Train the model + train_logistic_regression(opt, context_model, classification_model, train_loader) + + # Test the model + acc1, acc5, _ = test_logistic_regression( + opt, context_model, classification_model, test_loader + ) + + except KeyboardInterrupt: + print("Training got interrupted") + + logs.create_log( + context_model, + classification_model=classification_model, + accuracy=acc1, + acc5=acc5, + final_test=True, + ) + torch.save( + context_model.state_dict(), os.path.join(opt.log_path, "context_model.ckpt") + ) + + np.save(os.path.join(opt.model_path, "classification_results_values_"+str(opt.dataset)+".npy"), + np.array([acc1, acc5])) + L = ["Test top1 classification accuracy: "+str(acc1)+"\n", + "Test top5 classification accuracy: "+str(acc5)+"\n"] + f = open(os.path.join(opt.model_path, "classification_results_"+str(opt.dataset)+".txt"), "w") + f.writelines(L) + f.close() diff --git a/vision/GreedyInfoMax/vision/get_acc_supervised.py b/vision/GreedyInfoMax/vision/get_acc_supervised.py new file mode 100755 index 0000000..12f94a3 --- /dev/null +++ b/vision/GreedyInfoMax/vision/get_acc_supervised.py @@ -0,0 +1,68 @@ +import torch +import numpy as np +import time +import os + +from GreedyInfoMax.vision.data import get_dataloader +from GreedyInfoMax.vision.arg_parser import arg_parser +from GreedyInfoMax.vision.models import load_vision_model +from GreedyInfoMax.utils import logger, utils + + +def get_loss_and_accuracy(opt, model, data_loader): + losses = torch.zeros(opt.model_splits) + accuracies = torch.zeros(opt.model_splits) + for step, (img, target) in enumerate(data_loader): + + model_input = img.to(opt.device) + with torch.no_grad(): + loss, _, _, _, accuracy = model(model_input, target, n=opt.module_num) + + loss = torch.mean(loss, 0) # average over GPUs + accuracy = torch.mean(accuracy, 0) + + for idx in range(opt.model_splits): + losses[idx] += loss[idx] + accuracies[idx] += accuracy[idx] + + if step % 10 == 0: + print("evaluate batch number ", step, " out of ", len(data_loader)) + + return losses / len(data_loader), accuracies / len(data_loader) + +if __name__ == "__main__": + + opt = arg_parser.parse_args() + if opt.loss != 1: + raise ValueError("--loss keyword is not set to 1 (supervised). 
This only works for the Supervised Model") + + add_path_var = "linear_model" + + arg_parser.create_log_path(opt, add_path_var=add_path_var) + opt.training_dataset = "train" + + model, _ = load_vision_model.load_model_and_optimizer( + opt, reload_model=True, calc_loss=False + ) + model.module.switch_calc_loss(True) + + _, _, train_loader, _, test_loader, _ = get_dataloader.get_dataloader(opt) + + print("Evaluating loss and accuracy on train set...") + loss_train, acc_train = get_loss_and_accuracy(opt, model, train_loader) + print("Evaluating loss and accuracy on test set...") + loss_test, acc_test = get_loss_and_accuracy(opt, model, test_loader) + + np.save(os.path.join(opt.model_path, "classification_results_values.npy"), + np.array([loss_train.numpy(), acc_train.numpy(), loss_test.numpy(), acc_test.numpy()])) + L = ["Training losses for all modules: "+str(loss_train.numpy())+"\n", + "Training accuracies for all modules: "+str(acc_train.numpy())+"\n", + "Testing losses for all modules: "+str(loss_test.numpy())+"\n", + "Testing accuracies for all modules: "+str(acc_test.numpy())+"\n"] + f = open(os.path.join(opt.model_path, "classification_results.txt"), "w") + f.writelines(L) + f.close() + + for l in L: + print(l) + diff --git a/vision/GreedyInfoMax/vision/main_vision.py b/vision/GreedyInfoMax/vision/main_vision.py new file mode 100755 index 0000000..28e4189 --- /dev/null +++ b/vision/GreedyInfoMax/vision/main_vision.py @@ -0,0 +1,126 @@ +import torch +import time +import numpy as np + +from GreedyInfoMax.utils import logger +from GreedyInfoMax.vision.arg_parser import arg_parser +from GreedyInfoMax.vision.models import load_vision_model +from GreedyInfoMax.vision.data import get_dataloader + + +def train(opt, model, train_loader, optimizer): + total_step = len(train_loader) + model.module.switch_calc_loss(True) + + print_idx = 100 + + starttime = time.time() + cur_train_module = opt.train_module + + for epoch in range(opt.start_epoch, opt.num_epochs + opt.start_epoch): + + loss_epoch = [0 for i in range(opt.model_splits)] + loss_updates = [1 for i in range(opt.model_splits)] + + # loop over batches in train_loader + for step, (img, label) in enumerate(train_loader): + + if step % print_idx == 0: + print( + "Epoch [{}/{}], Step [{}/{}], Training Block: {}, Time (s): {:.1f}".format( + epoch + 1, + opt.num_epochs + opt.start_epoch, + step, + total_step, + cur_train_module, + time.time() - starttime, + ) + ) + + starttime = time.time() + + model_input = img.to(opt.device) + label = label.to(opt.device) + + # forward pass through whole model (loop over modules within model forward) + loss, loss_gated, _, _, accuracy = model(model_input, label, n=cur_train_module) + loss = torch.mean(loss, 0) # take mean over outputs of different GPUs + loss_gated = torch.mean(loss_gated, 0) # take mean over outputs of different GPUs + accuracy = torch.mean(accuracy, 0) + + if cur_train_module != opt.model_splits and opt.model_splits > 1: + raise ValueError("Training intermediate modules is not tested!") + # loss = loss[cur_train_module].unsqueeze(0) + + # loop through the losses of the modules and do gradient descent + for idx in range(len(loss)): + if len(loss) == 1 and opt.model_splits != 1: + idx = cur_train_module + + model.zero_grad() + + if idx == len(loss) - 1: # last module + loss[idx].backward() + else: + # select loss or loss_gated for gradient descent + loss_for_grads = loss_gated[idx] if model.module.opt.feedback_gating else loss[idx] + loss_for_grads.backward(retain_graph=True) + + 
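+                # note: each module's loss is backpropagated separately; only the optimizer of the
+                # current module idx (and, for the cross-module prediction options below, its
+                # neighbouring modules) is stepped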
optimizer[idx].step() + if opt.predict_module_num=='-1' or opt.predict_module_num=='both': + if idx != 0: + optimizer[idx-1].step() # to update lower (feature) layer + if opt.predict_module_num=='-1b' and not opt.skip_upper_c_update: + if idx != len(loss) - 1: + optimizer[idx+1].step() # to update upper (context) layer + + # We still output normal (ungated) loss for printing and plotting + print_loss = loss[idx].item() + if opt.asymmetric_W_pred: + print_loss *= 0.5 # loss is double in that case but gradients are still the same -> print the corresponding values + print_acc = accuracy[idx].item() + if step % print_idx == 0: + print("\t \t Loss: \t \t {:.4f}".format(print_loss)) + if opt.loss == 1: + print("\t \t Accuracy: \t \t {:.4f}".format(print_acc)) + + loss_epoch[idx] += print_loss + loss_updates[idx] += 1 + + logs.append_train_loss([x / loss_updates[idx] for idx, x in enumerate(loss_epoch)]) + logs.create_log(model, epoch=epoch, optimizer=optimizer) + +if __name__ == "__main__": + + opt = arg_parser.parse_args() + arg_parser.create_log_path(opt) + opt.training_dataset = "unlabeled" + + # random seeds + torch.manual_seed(opt.seed) + torch.cuda.manual_seed(opt.seed) + np.random.seed(opt.seed) + + if opt.device.type != "cpu": + torch.backends.cudnn.benchmark = True + + # load model + model, optimizer = load_vision_model.load_model_and_optimizer(opt) + + logs = logger.Logger(opt) + + train_loader, _, supervised_loader, _, test_loader, _ = get_dataloader.get_dataloader( + opt + ) + + if opt.loss == 1: + train_loader = supervised_loader + + try: + # Train the model + train(opt, model, train_loader, optimizer) + + except KeyboardInterrupt: + print("Training got interrupted, saving log-files now.") + + logs.create_log(model) diff --git a/vision/GreedyInfoMax/vision/models/ClassificationModel.py b/vision/GreedyInfoMax/vision/models/ClassificationModel.py new file mode 100755 index 0000000..eab6ad1 --- /dev/null +++ b/vision/GreedyInfoMax/vision/models/ClassificationModel.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + + +class ClassificationModel(torch.nn.Module): + def __init__(self, in_channels=256, num_classes=200, hidden_nodes=0): + super().__init__() + self.in_channels = in_channels + self.avg_pool = nn.AvgPool2d((7, 7), stride=0, padding=0) + self.model = nn.Sequential() + + if hidden_nodes > 0: + self.model.add_module( + "layer1", nn.Linear(self.in_channels, hidden_nodes, bias=True) + ) + + self.model.add_module("ReLU", nn.ReLU()) + self.model.add_module("Dropout", nn.Dropout(p=0.5)) + + self.model.add_module( + "layer 2", nn.Linear(hidden_nodes, num_classes, bias=True) + ) + + else: + self.model.add_module( + "layer1", nn.Linear(self.in_channels, num_classes, bias=True) + ) + + print(self.model) + + def forward(self, x, *args): + x = self.avg_pool(x).squeeze() + x = self.model(x).squeeze() + return x diff --git a/vision/GreedyInfoMax/vision/models/FullModel.py b/vision/GreedyInfoMax/vision/models/FullModel.py new file mode 100755 index 0000000..1363b9c --- /dev/null +++ b/vision/GreedyInfoMax/vision/models/FullModel.py @@ -0,0 +1,194 @@ +import torch +import torch.nn as nn + +from GreedyInfoMax.vision.models import VGG_like_Encoder + +class FullVisionModel(torch.nn.Module): + def __init__(self, opt, calc_loss): + super().__init__() + self.opt = opt + self.contrastive_samples = self.opt.negative_samples + if self.opt.current_rep_as_negative: + print("Contrasting against current representation (i.e. 
only one negative sample)") + else: + print("Contrasting against ", self.contrastive_samples, " negative sample(s)") + self.calc_loss = calc_loss + self.encoder_type = self.opt.encoder_type + print("Using ", self.encoder_type, " encoder") + self.predict_module_num = self.opt.predict_module_num + self.increasing_patch_size = self.opt.increasing_patch_size + if self.predict_module_num=='-1': + print("Predicting lower module, 1st module predicting same module") + elif self.predict_module_num=='both': + print("Predicting both, same and lower module") + elif self.predict_module_num=='-1b': + print("Predicting lower module, last module predicting same module") + + if self.opt.inference_recurrence == 0: # 0 - no recurrence + self.recurrence_iters = 0 + else: # 1 - lateral recurrence within layer, 2 - feedback recurrence, 3 - both, lateral and feedback recurrence + self.recurrence_iters = self.opt.recurrence_iters + + self.model, self.encoder, self.autoregressor = self._create_full_model(opt) + + print(self.model) + + def _create_full_model(self, opt): + if self.encoder_type=='vgg_like': + full_model, encoder = self._create_full_model_vgg(opt) + else: + raise Exception("Invalid encoder option") + + return full_model, encoder, None + + + def _create_full_model_vgg(self, opt): + if type(opt.patch_size) == int: + patch_sizes = [opt.patch_size for _ in range(opt.model_splits)] + else: + patch_sizes = opt.patch_size + + arch = [128, 256, 'M', 256, 512, 'M', 1024, 'M', 1024, 'M'] + if opt.model_splits == 1: + blocks = [arch] + elif opt.model_splits == 2: + blocks = [arch[:4], arch[4:]] + elif opt.model_splits == 4: + blocks = [arch[:4], arch[4:6], arch[6:8], arch[8:]] + elif opt.model_splits == 3: + blocks = [arch[:3], arch[3:6], arch[6:]] + elif opt.model_splits == 6: + blocks = [arch[:1], arch[1:3], arch[3:4], arch[4:6], arch[6:8], arch[8:]] + else: + raise NotImplementedError + + full_model = nn.ModuleList([]) + encoder = nn.ModuleList([]) + + if opt.grayscale: + input_dims = 1 + else: + input_dims = 3 + + output_dims = arch[-2] * 4 + + for idx, _ in enumerate(blocks): + if idx==0: + in_channels = input_dims + else: + if blocks[idx-1][-1] == 'M': + in_channels = blocks[idx-1][-2] + else: + in_channels = blocks[idx-1][-1] + + encoder.append( + VGG_like_Encoder.VGG_like_Encoder(opt, + idx, + blocks, + in_channels, + calc_loss=False, + patch_size=patch_sizes[idx], + ) + ) + + full_model.append(encoder) + + return full_model, encoder + + ########################################################################################################### + # forward + + def forward(self, x, label, n=3): + # n: until which module to perform the forward pass + model_input = x + + if self.opt.device.type != "cpu": + cur_device = x.get_device() + else: + cur_device = self.opt.device + + n_patches_x, n_patches_y = None, None + + + outs = [] + + if n==-1: # return (reshaped/flattened) input image, for direct classification + s = model_input.shape # b, in_channels, y, x + h = model_input.reshape(s[0], s[1]*s[2]*s[3]).unsqueeze(-1).unsqueeze(-1) # b, in_channels*y*x + else: + reps = None + for t in range(self.recurrence_iters+1): # 0-th iter for normal feedforward pass + model_input = x + acts = [] + # forward loop through modules + for idx, module in enumerate(self.encoder[:n]): + # block gradient of h at some point -> should be blocked after one module since input was detached + h, z, n_patches_y, n_patches_x = module( + model_input, reps, t, n_patches_y, n_patches_x, label + ) + # detach z to make sure no gradients are 
flowing in between modules + # we can detach z here, as for the CPC model the loop is only called once and h is forward-propagated + model_input = z.clone().detach() # full module output + acts.append(model_input) # needed for optional recurrence + if t == self.recurrence_iters: + outs.append(h) # out: mean pooled per patch + + reps = acts + + loss, loss_gated, accuracies = self.evaluate_losses(outs, label, cur_device, n=n) + + c = None # Can be used if context is of different kind than h (e.g. output of recurrent layer) + + return loss, loss_gated, c, h, accuracies + + def evaluate_losses(self, outs, label, cur_device, n = 3): + loss = torch.zeros(1, self.opt.model_splits, device=cur_device) # first dimension for multi-GPU training + loss_gated = torch.zeros(1, self.opt.model_splits, device=cur_device) # first dimension for multi-GPU training + accuracies = torch.zeros(1, self.opt.model_splits, device=cur_device) # first dimension for multi-GPU training + + # loop BACKWARDS through module outs and calculate losses + # backward loop is necessary because of potential feedback gating! + for idx in range(n-1, -1, -1): # backward loop: upper, lower, step + if self.opt.feedback_gating: + if idx == self.opt.model_splits-1: # no gating for highest layer + gating = None + else: + gating = None + + cur_loss, cur_loss_gated, cur_accuracy, gating = self.encoder[idx].evaluate_loss(outs, idx, label, gating=gating) + + if cur_loss is not None: + loss[:, idx] = cur_loss + loss_gated[:, idx] = cur_loss_gated + accuracies[:, idx] = cur_accuracy + + return loss, loss_gated, accuracies + + def switch_calc_loss(self, calc_loss): + # by default models are set to not calculate the loss as it is costly + # this function can enable the calculation of the loss for training + self.calc_loss = calc_loss + if self.opt.model_splits == 1 and self.opt.loss == 0: + if self.employ_autoregressive: + self.autoregressor.calc_loss = calc_loss + else: + self.encoder[-1].calc_loss = calc_loss + + if self.opt.model_splits == 1 and self.opt.loss == 1: + self.encoder[-1].calc_loss = calc_loss + + if self.opt.model_splits > 1: + if self.opt.train_module != self.opt.model_splits: + cont = input("WARNING: model_splits > 1 and train_module != model_splits." + " (this could mean that not all modules are trained; ignore when training classifier)." 
+ " Please think again if you really want that and enter 'y' to continue: ") + if cont == "y": + return + else: + raise ValueError("Interrupting...") + + if self.opt.train_module == self.opt.model_splits: + for i, layer in enumerate(self.encoder): + layer.calc_loss = calc_loss + else: + self.encoder[self.opt.train_module].calc_loss = calc_loss diff --git a/vision/GreedyInfoMax/vision/models/InfoNCE_Loss.py b/vision/GreedyInfoMax/vision/models/InfoNCE_Loss.py new file mode 100755 index 0000000..374ca82 --- /dev/null +++ b/vision/GreedyInfoMax/vision/models/InfoNCE_Loss.py @@ -0,0 +1,391 @@ + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _WeightedLoss +import torch.nn.functional as F +import numpy as np +from numpy.random import choice +from IPython import embed +import os + +from GreedyInfoMax.utils import model_utils + + +class InfoNCE_Loss(nn.Module): + def __init__(self, opt, in_channels, out_channels, prediction_steps, save_vars=False): # in_channels: z, out_channels: c + super().__init__() + self.opt = opt + self.negative_samples = self.opt.negative_samples + self.k_predictions = prediction_steps + self.contrast_mode = self.opt.contrast_mode + self.average_feedback_gating = self.opt.gating_av_over_preds + self.detach_c = self.opt.detach_c + self.current_rep_as_negative = self.opt.current_rep_as_negative + self.sample_negs_locally = self.opt.sample_negs_locally + self.sample_negs_locally_same_everywhere = self.opt.sample_negs_locally_same_everywhere + self.either_pos_or_neg_update = self.opt.either_pos_or_neg_update + self.which_update = 'both' + self.save_vars = save_vars + + if self.current_rep_as_negative: + self.negative_samples = 1 + + self.W_k = nn.ModuleList( + nn.Conv2d(in_channels, out_channels, 1, bias=False) # in_channels: z, out_channels: c + for _ in range(self.k_predictions) + ) + + if self.opt.freeze_W_pred: # freeze prediction weights W_k + if self.opt.unfreeze_last_W_pred: + params_to_freeze = self.W_k[:-1].parameters() + else: + params_to_freeze = self.W_k.parameters() + for p in params_to_freeze: + p.requires_grad = False + + if self.contrast_mode == 'multiclass' or self.contrast_mode == 'binary': + self.contrast_loss = ExpNLLLoss() + elif self.contrast_mode == 'logistic': + self.contrast_loss = MyBCEWithLogitsLoss() + elif self.contrast_mode == 'hinge': + self.contrast_loss = HingeLoss() + + if self.opt.weight_init: + self.initialize() + + + def initialize(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d,)): + if m in self.W_k: + model_utils.makeDeltaOrthogonal( + m.weight, + nn.init.calculate_gain( + "Sigmoid" + ), + ) + + + def forward(self, z, c, skip_step=1, gating=None): + # gating should be either None or (nested) list of gating values for each prediction step (k_predictions) + # z: b, channels, n_patches_y, n_patches_x + if self.detach_c: + c = c.clone().detach() # drop gradient of context + + batch_size = z.shape[0] + + # If self.either_pos_or_neg_update is True, select whether pos or neg update (or both) is done, independently for every sample in batch + # p = [0.5,0.5,0.] for equal sampling, p = [0.,0.,1.] 
implements normal HingeCPC + if self.either_pos_or_neg_update: + self.which_update = choice(['pos','neg','both'], size = batch_size, replace=True, p = [0.5,0.5,0.]) + + total_loss, total_loss_gated = 0, 0 + gating_out = [] + if gating is not None and self.average_feedback_gating: # average gating over k predictions + g_pos = sum([g[0] for g in gating]) / self.k_predictions + g_neg = sum([g[1] for g in gating]) / self.k_predictions + gating = [g_pos, g_neg] # 2 x b each + + if self.opt.device.type != "cpu": + cur_device = z.get_device() + else: + cur_device = self.opt.device + + # Loop over different prediction intervals, For each element in c, contrast with elements below + for k in range(1, self.k_predictions + 1): + ### compute log f(c_t, x_{t+k}) = z^T_{t+k} W_k c_t + # compute z^T_{t+k} W_k: + ztwk = ( + self.W_k[k - 1] + .forward(z[:, :, (k + skip_step) :, :]) # Bx, C , H , W + .permute(2, 3, 0, 1) # H, W, Bx, C + .contiguous() + ) # y, x, b, c + + # Creation of neg. examples + ztwk_shuf, rand_index = self.create_negative_samples(k, skip_step, z, ztwk, cur_device) # y, x, b, c, n + + #### Compute x_W1 . c_t: + # context: b, c, H, W = x + context = ( + c[:, :, : -(k + skip_step), :].permute(2, 3, 0, 1).unsqueeze(-2) + ) # y (reduced H), x, b, 1, c + + log_fk_main = torch.matmul(context, ztwk.unsqueeze(-1)).squeeze( + -2 + ) # y, x, b, 1 + + log_fk_shuf = torch.matmul(context, ztwk_shuf).squeeze(-2) # y, x, b, n + + if self.contrast_mode=='multiclass': + log_fk, target = self.multiclass_contrasting(log_fk_main, log_fk_shuf, + batch_size, cur_device) # b, 1+n, y, x; b, y, x + elif self.contrast_mode=='binary': + log_fk, target = self.binary_contrasting(log_fk_main, log_fk_shuf, + batch_size, cur_device) # b, 1+1, n, y, x; b, n, y, x + elif self.contrast_mode=='logistic' or self.contrast_mode=='hinge': + if self.opt.no_pred: # Wpred * c is set to 1 (no prediction). i.e. fourth factor omitted in learning rule. 
In this case, the score function is equal to the sum of activations + log_fk_main = ztwk.sum(dim=-1).unsqueeze(-1) + context.sum(dim=-1) # y, x, b, 1 + log_fk_shuf = ztwk_shuf.sum(dim=-2) + context.sum(dim=-1).repeat(1, 1, 1, self.negative_samples) # y, x, b, n + + log_fk, target = self.logistic_contrasting(log_fk_main, log_fk_shuf, + batch_size, cur_device) # b, 1+1, n, y, x (both) + + if gating is None: + gate = None + else: + if self.average_feedback_gating: + gate = gating # already average over k + else: + gate = gating[k-1] # k-1 because k is in range(1,k_predictions+1) + + loss_k, loss_k_gated, gating_out_k = self.contrast_loss(self.opt, k, input=log_fk, target=target, gating=gate, which_update=self.which_update, save_vars=self.save_vars) + total_loss += loss_k + + if loss_k_gated is not None: + total_loss_gated += loss_k_gated + + gating_out.append(gating_out_k) + + if self.save_vars: + z_s = z[:, :, (k + skip_step) :, :].permute(2, 3, 0, 1).clone() # y (red.), x, b, c + torch.save((context.clone(), z_s.clone(), z_s[:, :, rand_index, :].clone(), rand_index), os.path.join(self.opt.model_path, 'saved_c_and_z_layer_'+str(self.opt.save_vars_for_update_calc)+'_k_'+str(k))) + + if self.save_vars: + if type(self.which_update) == str: + which_update_save = self.which_update + else: + which_update_save = self.which_update.copy() + torch.save(which_update_save, os.path.join(self.opt.model_path, 'saved_which_update_layer_'+str(self.opt.save_vars_for_update_calc))) + + total_loss /= self.k_predictions + total_loss_gated /= self.k_predictions + + return total_loss, total_loss_gated, gating_out + + + def multiclass_contrasting(self, log_fk_main, log_fk_shuf, + batch_size, cur_device): + """ contrasting all the negative examples at the same time via multi-class classification""" + log_fk = torch.cat((log_fk_main, log_fk_shuf), 3) # y, x, b, 1+n + log_fk = log_fk.permute(2, 3, 0, 1) # b, 1+n, y, x This is the shape expected by nll_loss + + log_fk = torch.softmax(log_fk, dim=1) + + target = torch.zeros( + (batch_size, log_fk.shape[-2], log_fk.shape[-1]), + dtype=torch.long, + device=cur_device, + ) # b, y, x + return log_fk, target # b, 1+n, y, x; b, y, x + + def _get_log_fk(self, log_fk_main, log_fk_shuf): + log_fk_main = log_fk_main.repeat(1, 1, 1, self.negative_samples) # y, x, b, n + + log_fk = torch.cat((log_fk_main.unsqueeze(-1), log_fk_shuf.unsqueeze(-1)), 4) # y, x, b, n, 1+1 + log_fk = log_fk.permute(2, 4, 3, 0, 1) # b, 1+1, n, y, x This is the shape expected by nll_loss + return log_fk + + def binary_contrasting(self, log_fk_main, log_fk_shuf, + batch_size, cur_device): + """ contrasting all the negative examples independently and later average over losses""" + log_fk = self._get_log_fk(log_fk_main, log_fk_shuf) # b, 1+1, n, y, x + + log_fk = torch.softmax(log_fk, dim=1) + target = torch.zeros( + (batch_size, self.negative_samples, log_fk.shape[-2], log_fk.shape[-1]), + dtype=torch.long, + device=cur_device, + ) # b, n, y, x + return log_fk, target # b, 1+1, n, y, x; b, n, y, x + + def logistic_contrasting(self, log_fk_main, log_fk_shuf, + batch_size, cur_device): + """ contrasting by doing binary logistic regression on pos. and neg. ex. 
separately and later average over losses""" + log_fk = self._get_log_fk(log_fk_main, log_fk_shuf) # b, 1+1, n, y, x + + zeros = torch.zeros( + (batch_size, self.negative_samples, log_fk.shape[-2], log_fk.shape[-1]), + dtype=torch.float32, + device=cur_device, + ) # b, n, y, x + ones = torch.ones( + (batch_size, self.negative_samples, log_fk.shape[-2], log_fk.shape[-1]), + dtype=torch.float32, + device=cur_device, + ) # b, n, y, x + target = torch.cat((ones.unsqueeze(1), zeros.unsqueeze(1)), 1) # b, 1+1, n, y, x + return log_fk, target # b, 1+1, n, y, x (both) + + def sample_negative_samples(self, ztwk, cur_device): + ztwk_shuf = ztwk.view( + ztwk.shape[0] * ztwk.shape[1] * ztwk.shape[2], ztwk.shape[3] + ) # y * x * b, c + rand_index = torch.randint( + ztwk_shuf.shape[0], # upper limit: y * x * b + (ztwk_shuf.shape[0] * self.negative_samples, 1), # shape: y * x * b * n, 1 + dtype=torch.long, + device=cur_device, + ) + # replicate neg. sample indices for all channels + rand_index = rand_index.repeat(1, ztwk_shuf.shape[1]) # y * x * b * n, c + + ztwk_shuf = torch.gather( + ztwk_shuf, dim=0, index=rand_index, out=None + ) # y * x * b * n, c + + ztwk_shuf = ztwk_shuf.view( + ztwk.shape[0], + ztwk.shape[1], + ztwk.shape[2], + self.negative_samples, + ztwk.shape[3], + ).permute( + 0, 1, 2, 4, 3 + ) # y, x, b, c, n + + return ztwk_shuf, rand_index + + def sample_negative_samples_locally(self, ztwk, cur_device, same_sampling_everywhere=False): + # ztwk: y, x, b, c + # same sampling (same sample from batch) for all locations + if same_sampling_everywhere: + rand_index = torch.randint(ztwk.shape[2], # upper limit: b + (ztwk.shape[2],), # shape: b, assumes n=1 neg. samples, + dtype=torch.long, + device=cur_device, + ) + ztwk_shuf = ztwk[:, :, rand_index, :] # y, x, b, c + # or different sampling (different sample from batch) for different locations: + else: + rand_index = torch.randint(ztwk.shape[2], # upper limit: b + (ztwk.shape[0], ztwk.shape[1], ztwk.shape[2]), # shape: y, x, b, assumes n=1 neg. samples, + dtype=torch.long, + device=cur_device, + ) + # replicate neg. sample indices for all channels + rand_index = rand_index.repeat(ztwk.shape[-1], 1, 1, 1).permute(1, 2, 3, 0) # y, x, b, c + ztwk_shuf = torch.gather( + ztwk, dim=2, index=rand_index, out=None + ) # y, x, b, c + + ztwk_shuf = ztwk_shuf.unsqueeze(-1) # y, x, b, c, n=1 + return ztwk_shuf, rand_index + + # Creation of neg. 
examples + def create_negative_samples(self, k, skip_step, z, ztwk, cur_device): + if self.current_rep_as_negative: # (unsuccesful) idea of using averaged activity over batch ("memory trace") + ztwk_context = ( + self.W_k[k - 1] + .forward(z[:, :, : -(k + skip_step), :]) # Bx, C , H , W + .permute(2, 3, 0, 1) # H, W, Bx, C + .contiguous() + ) # y, x, b, c (number of negative examples is set to n=1 in that case) + ztwk_shuf, rand_index = self.sample_negative_samples_locally(ztwk_context, cur_device, same_sampling_everywhere=self.sample_negs_locally_same_everywhere) + else: + if self.sample_negs_locally: + ztwk_shuf, rand_index = self.sample_negative_samples_locally(ztwk, cur_device, same_sampling_everywhere=self.sample_negs_locally_same_everywhere) + else: + ztwk_shuf, rand_index = self.sample_negative_samples(ztwk, cur_device) + + return ztwk_shuf, rand_index + +############################################################################################################## +# Contrastive loss functions + +class ExpNLLLoss(_WeightedLoss): + def __init__(self, weight=None, size_average=None, ignore_index=-100, + reduce=None, reduction='mean'): + super(ExpNLLLoss, self).__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, opt, k, input, target, gating=None, which_update='both', save_vars=False): + if which_update != 'both': + raise ValueError("which_update must be both for ExpNLLLoss") + x = torch.log(input + 1e-11) + loss = F.nll_loss(x, target, weight=self.weight, ignore_index=self.ignore_index, + reduction=self.reduction) + return loss, None, None + + +class MyBCEWithLogitsLoss(_WeightedLoss): + def __init__(self): + super(MyBCEWithLogitsLoss, self).__init__() + self.loss_func = nn.BCEWithLogitsLoss() + + def forward(self, opt, k, input, target, gating=None, which_update='both', save_vars=False): + if which_update != 'both': + raise ValueError("which_update must be both for BCEWithLogitsLoss") + loss = self.loss_func(input, target) + return loss, None, None + +class HingeLoss(_WeightedLoss): + def __init__(self): + super(HingeLoss, self).__init__() + + def forward(self, opt, k, input, target, gating=None, which_update='both', save_vars=False): # b, 1+1, n, y, x (both) + # Take care: pos sample appears N times for N neg. 
examples + # gating should be 2 1-dim vectors (for pos and neg samples) with length b containing weights/gating values for each image in batch + + def _normalise_gating(g): + g_mean = g.mean() + g = g - g_mean + g = g / g_mean + g = torch.sigmoid(3 * g) + g /= g.shape[0] # this normalisation makes the sum in matmul in the gated loss an (unnormed) average + return g + + def _add_losses(loss_pos, loss_neg, which_update, cur_device): + l = 0 + if type(which_update) == str: + if which_update == 'both': # default case + l = loss_pos.mean() + loss_neg.mean() + else: + for loss, c in zip([loss_pos, loss_neg], ['pos', 'neg',]): + ind = (which_update == c) | (which_update == 'both') + if sum(ind) > 0: # exclude empty sets which lead no NaN in loss + l += torch.masked_select(loss, torch.tensor(ind.tolist()).to(cur_device)).mean() + return l + + cur_device = input.get_device() + + scores_pos = input[:,0,:,:,:] # b, n, y, x + scores_neg = input[:,1,:,:,:] # b, n, y, x + + ones = torch.ones(size=scores_pos.shape, dtype=torch.float32, device=cur_device) + zeros = torch.zeros(size=scores_neg.shape, dtype=torch.float32, device=cur_device) + + if opt.no_gamma: # gamma (factor which sets the opposite sign of the update for pos and neg samples and sets gating) is set to 1. i.e. third factor omitted in learning rule + losses_pos = ones - scores_pos + losses_neg = ones - scores_neg + else: + losses_pos = torch.where(scores_pos < ones, ones - scores_pos, zeros) # b, n, y, x + losses_neg = torch.where(scores_neg > - ones, ones + scores_neg, zeros) # b, n, y, x + + if save_vars: + torch.save((losses_pos, losses_neg), os.path.join(opt.model_path, 'saved_losses_layer_'+str(opt.save_vars_for_update_calc)+'_k_'+str(k))) + + # if gating values are given, take weighted sum before averaging over remaining dimensions + if gating == None: + loss_gated = None + else: + losses_pos_gated = torch.matmul(losses_pos.permute(1,2,3,0), gating[0]) + losses_neg_gated = torch.matmul(losses_neg.permute(1,2,3,0), gating[1]) + loss_gated = _add_losses(losses_pos_gated, losses_neg_gated, which_update, cur_device) + + losses_pos_per_sample = losses_pos.mean(dim=(-1,-2,-3)) + losses_neg_per_sample = losses_neg.mean(dim=(-1,-2,-3)) + + # detach gating such that no gradient of original loss is back-propagated!, + # clone is important so that later normalisation of gating does not influence loss + gating_pos = losses_pos_per_sample.clone().detach() + gating_neg = losses_neg_per_sample.clone().detach() + + gating_pos = _normalise_gating(gating_pos) + gating_neg = _normalise_gating(gating_neg) + + gating_out = [gating_pos, gating_neg] + + loss = _add_losses(losses_pos_per_sample, losses_neg_per_sample, which_update, cur_device) # average over remaining batch dimension + + return loss, loss_gated, gating_out diff --git a/vision/GreedyInfoMax/vision/models/Supervised_Loss.py b/vision/GreedyInfoMax/vision/models/Supervised_Loss.py new file mode 100755 index 0000000..e5dbf75 --- /dev/null +++ b/vision/GreedyInfoMax/vision/models/Supervised_Loss.py @@ -0,0 +1,52 @@ +import torch.nn as nn +import torch + +from GreedyInfoMax.utils import utils + +class Supervised_Loss(nn.Module): + def __init__(self, opt, hidden_dim, calc_accuracy): + super(Supervised_Loss, self).__init__() + + self.opt = opt + + self.pool = None + self.hidden_dim = hidden_dim + self.calc_accuracy = calc_accuracy + + # create linear classifier + if opt.dataset == "stl10": + n_classes = 10 + else: + raise Exception("Other datasets are not implemented yet") + + self.linear_classifier = 
nn.Sequential( + nn.Linear(self.hidden_dim, n_classes) + ).to(self.opt.device) + + self.classification_loss = nn.CrossEntropyLoss() + + self.label_num = 1 + + + def forward(self, z, label): + total_loss, accuracies = self.calc_supervised_loss( + z, label + ) + return total_loss, accuracies + + + def calc_supervised_loss(self, z, labels): + # forward pass + z = nn.functional.adaptive_avg_pool2d(z, 1).squeeze() + + output = self.linear_classifier(z) + + loss = self.classification_loss(output, labels) + + accuracy = torch.zeros(1) + + # calculate accuracy + if self.calc_accuracy: + accuracy[0], = utils.accuracy(output.data, labels, topk=(1,)) + + return loss, accuracy diff --git a/vision/GreedyInfoMax/vision/models/VGG_like_Encoder.py b/vision/GreedyInfoMax/vision/models/VGG_like_Encoder.py new file mode 100755 index 0000000..49c58e6 --- /dev/null +++ b/vision/GreedyInfoMax/vision/models/VGG_like_Encoder.py @@ -0,0 +1,345 @@ +# Partly taken from (30 July 2020) +# https://pytorch.org/docs/stable/_modules/torchvision/models/vgg.html#vgg11 +import torch.nn as nn +import torch.nn.functional as F +import torch +import numpy as np +import os + +from GreedyInfoMax.vision.models import InfoNCE_Loss, Supervised_Loss +from GreedyInfoMax.utils import model_utils + +class VGG_like_Encoder(nn.Module): + def __init__( + self, + opt, + block_idx, + blocks, + in_channels, + patch_size=16, + overlap_factor=2, + calc_loss=False, + ): + super(VGG_like_Encoder, self).__init__() + self.encoder_num = block_idx + self.opt = opt + + self.save_vars = self.opt.save_vars_for_update_calc == block_idx+1 + + # Layer + self.model = self.make_layers(blocks[block_idx], in_channels) + + # Params + self.calc_loss = calc_loss + + self.overlap = overlap_factor + self.increasing_patch_size = self.opt.increasing_patch_size + if self.increasing_patch_size: # This is experimental... take care, this must be synced with architecture, i.e. number and position of downsampling layers (stride 2, e.g. 
pooling) + if self.overlap != 2: + raise ValueError("if --increasing_patch_size is true, overlap(_factor) has to be equal 2") + patch_sizes = [4, 4, 8, 8, 16, 16] + self.patch_size_eff = patch_sizes[block_idx] + self.max_patch_size = max(patch_sizes) + high_level_patch_sizes = [4, 4, 4, 4, 4, 2] + self.patch_size = high_level_patch_sizes[block_idx] + else: + self.patch_size = patch_size + + reduced_patch_pool_sizes = [4, 4, 3, 3, 2, 1] + if opt.reduced_patch_pooling: + self.patch_average_pool_out_dim = reduced_patch_pool_sizes[block_idx] + else: + self.patch_average_pool_out_dim = 1 + + self.predict_module_num = self.opt.predict_module_num + self.extra_conv = self.opt.extra_conv + self.inpatch_prediction = self.opt.inpatch_prediction + self.inpatch_prediction_limit = self.opt.inpatch_prediction_limit + self.asymmetric_W_pred = self.opt.asymmetric_W_pred + + if opt.gradual_prediction_steps: + prediction_steps = min(block_idx+1, self.opt.prediction_step) + else: + prediction_steps = self.opt.prediction_step + + def get_last_index(block): + if block[-1] == 'M': + last_ind = -2 + else: + last_ind = -1 + return last_ind + + last_ind = get_last_index(blocks[block_idx]) + self.in_planes = blocks[block_idx][last_ind] + # in_channels_loss: z, out_channels: c + if self.predict_module_num=='-1' or self.predict_module_num=='both': + if self.encoder_num == 0: # exclude first module + in_channels_loss = self.in_planes + if opt.reduced_patch_pooling: + in_channels_loss *= reduced_patch_pool_sizes[block_idx] ** 2 + else: + last_ind_block_below = get_last_index(blocks[block_idx-1]) + in_channels_loss = blocks[block_idx-1][last_ind_block_below] + if opt.reduced_patch_pooling: + in_channels_loss *= reduced_patch_pool_sizes[block_idx-1] ** 2 + else: + in_channels_loss = self.in_planes + if opt.reduced_patch_pooling: + in_channels_loss *= reduced_patch_pool_sizes[block_idx] ** 2 + + # Optional extra conv layer to increase rec. 
field size + if self.extra_conv and self.encoder_num < 3: + self.extra_conv_layer = nn.Conv2d(self.in_planes, self.in_planes, stride=3, kernel_size=3, padding=1) + + # in_channels_loss: z, out_channels: c + if self.predict_module_num == '-1b': + if self.encoder_num == len(blocks)-1: # exclude last module + out_channels = self.in_planes + if opt.reduced_patch_pooling: + out_channels *= reduced_patch_pool_sizes[block_idx] ** 2 + else: + last_ind_block_above = get_last_index(blocks[block_idx+1]) + out_channels = blocks[block_idx+1][last_ind_block_above] + if opt.reduced_patch_pooling: + out_channels *= reduced_patch_pool_sizes[block_idx+1] ** 2 + else: + out_channels = self.in_planes + if opt.reduced_patch_pooling: + out_channels *= reduced_patch_pool_sizes[block_idx] ** 2 + + + # Loss module; is always present, but only gets used when training GreedyInfoMax modules + # in_channels_loss: z, out_channels: c + if self.opt.loss == 0: + self.loss = InfoNCE_Loss.InfoNCE_Loss( + opt, + in_channels=in_channels_loss, # z + out_channels=out_channels, # c + prediction_steps=prediction_steps, + save_vars=self.save_vars + ) + if self.predict_module_num == 'both': + self.loss_same_module = InfoNCE_Loss.InfoNCE_Loss( + opt, + in_channels=in_channels_loss, + out_channels=in_channels_loss, # on purpose, cause in_channels_loss is layer below + prediction_steps=prediction_steps + ) + if self.asymmetric_W_pred: + self.loss_mirror = InfoNCE_Loss.InfoNCE_Loss( + opt, + in_channels=in_channels_loss, + out_channels=out_channels, + prediction_steps=prediction_steps + ) + elif self.opt.loss == 1: + self.loss = Supervised_Loss.Supervised_Loss(opt, in_channels_loss, True) + else: + raise Exception("Invalid option") + + + # Optional recurrent weights, Experimental! + if self.opt.inference_recurrence == 1 or self.opt.inference_recurrence == 3: # 1 - lateral recurrence within layer + self.recurrent_weights = nn.Conv2d(self.in_planes, self.in_planes, 1, bias=False) + if self.opt.inference_recurrence == 2 or self.opt.inference_recurrence == 3: # 2 - feedback recurrence, 3 - both, lateral and feedback recurrence + if self.encoder_num < len(blocks)-1: # exclude last module + last_ind_block_above = get_last_index(blocks[block_idx+1]) + rec_dim_block_above = blocks[block_idx+1][last_ind_block_above] + self.recurrent_weights_fb = nn.Conv2d(rec_dim_block_above, self.in_planes, 1, bias=False) + + if self.opt.weight_init: + raise NotImplementedError("Weight init not implemented for vgg") + + def make_layers(self, block, in_channels, batch_norm=False, inplace=False): + layers = [] + for v in block: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=inplace)] + else: + layers += [conv2d, nn.ReLU(inplace=inplace)] + in_channels = v + return nn.Sequential(*layers) + + + def forward(self, x, reps, t, n_patches_y, n_patches_x, label): + # x: either input dims b, c, Y, X or (if coming from lower module which did unfolding, as variable z): b * n_patches_y * n_patches_x, c, y, x + + # Input preparation, i.e unfolding into patches. Usually only needed for first module. More complicated for experimental increasing_patch_size option. 
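+        # (editor's note, illustrative sketch; not part of the original patch, numbers below are assumptions)
+        # The unfold below cuts the image into overlapping patches with step patch_size // overlap.
+        # E.g. assuming a 64x64 input crop with patch_size=16 and overlap_factor=2 (step 8):
+        #   n_patches_y = n_patches_x = (64 - 16) // 8 + 1 = 7
+        # so x goes from (b, c, 64, 64) to (b * 7 * 7, c, 16, 16) before entering self.model.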
+ if self.encoder_num in [0,2,4]: # [0,2,4,5] + # if increasing_patch_size is enabled, this has to be in sync with architecture and intended patch_size for respective module: + # for every layer that increases the patch_size, the extra downsampling + unfolding has to be done! + if self.encoder_num > 0 and self.increasing_patch_size: + # undo unfolding of the previous module + s1 = x.shape + x = x.reshape(-1, n_patches_y, n_patches_x, s1[1], s1[2], s1[3]) # b, n_patches_y, n_patches_x, c, y, x + # downsampling to get rid of the overlaps between paches of the previous module + x = x[:,::2,::2,:,:,:] # b, n_patches_x_red, n_patches_y_red, c, y, x. + s = x.shape + x = x.permute(0,3,2,5,1,4).reshape(s[0],s[3],s[2],s[5],s[1]*s[4]).permute(0,1,4,2,3).reshape(s[0],s[3],s[1]*s[4],s[2]*s[5]) # b, c, Y, X + + if self.encoder_num == 0 or self.increasing_patch_size: + x = ( # b, c, y, x + x.unfold(2, self.patch_size, self.patch_size // self.overlap) # b, c, n_patches_y, x, patch_size + .unfold(3, self.patch_size, self.patch_size // self.overlap) # b, c, n_patches_y, n_patches_x, patch_size, patch_size + .permute(0, 2, 3, 1, 4, 5) # b, n_patches_y, n_patches_x, c, patch_size, patch_size + ) + n_patches_y = x.shape[1] + n_patches_x = x.shape[2] + x = x.reshape( + x.shape[0] * x.shape[1] * x.shape[2], x.shape[3], x.shape[4], x.shape[5] + ) # b * n_patches_y * n_patches_x, c, patch_size, patch_size + + # Main encoding step + # forward through self.model is split into (conv)/(nonlin + pool) due to (optional) recurrence + # assuming arch = [128, 256, 'M', 256, 512, 'M', 1024, 'M', 1024, 'M'] + if self.opt.inference_recurrence > 0: # in case of recurrence + if self.opt.model_splits == 6: + split_ind = 1 + elif self.opt.model_splits == 3 or self.opt.model_splits == 1: + split_ind = -2 + else: + raise NotImplementedError("Recurrence is only implemented for model_splits = 1, 3 or 6") + else: # without recurrence split does not really matter, arbitrily choose 1 + split_ind = 1 + + if self.save_vars: # save input for (optional) manual update calculation + torch.save(x, os.path.join(self.opt.model_path, 'saved_input_layer_'+str(self.opt.save_vars_for_update_calc))) + + # 1. Apply encoding weights of model (e.g. conv2d layer) + z = self.model[:split_ind](x) # b * n_patches_y * n_patches_x, c, y, x + + # 2. Add (optional) recurrence if present + # expand dimensionality if rec comes from layer after one or several (2x2 strided, i.e. downsampled) pooling layer(s). 
tensor.repeat_interleave() would do, but not available in pytorch 1.0.0 + def expand_2_by_2(rec): + srec = rec.shape + return ( + rec.unfold(2,1,1).repeat((1,1,1,1,2)).permute(0,1,2,4,3).reshape(srec[0],srec[1],2*srec[2],srec[3]) + .unfold(3,1,1).repeat((1,1,1,1,2)).reshape(srec[0],srec[1],2*srec[2],2*srec[3]) + ) + if t > 0: # only apply rec if iteration is not the first one (is not entered if recurrence is off since then t = 0) + if self.opt.inference_recurrence == 1 or self.opt.inference_recurrence == 3: # 1 - lateral recurrence within layer + rec = self.recurrent_weights(reps[self.encoder_num].clone().detach()) # Detach input to implement e-prop like BPTT + while z.shape != rec.shape: # if rec comes from strided pooling layer + rec = expand_2_by_2(rec) + z += rec + if self.opt.inference_recurrence == 2 or self.opt.inference_recurrence == 3: # 2 - feedback recurrence, 3 - both, lateral and feedback recurrence + if self.encoder_num < len(reps)-1: # exclude last module + rec_fb = self.recurrent_weights_fb(reps[self.encoder_num+1].clone().detach()) # Detach input to implement e-prop like BPTT + while z.shape != rec_fb.shape: # if rec comes from strided pooling layer + rec_fb = expand_2_by_2(rec_fb) + z += rec_fb + + # 3. Apply nonlin and 'rest' of model (e.g. ReLU, MaxPool etc...) + z = self.model[split_ind:](z) # b * n_patches_y * n_patches_x, c, y, x + + # Optional extra conv layer with downsampling (stride > 1) here to increase receptive field size ### + if self.extra_conv and self.encoder_num < 3: + dec = self.extra_conv_layer(z) + dec = F.relu(dec, inplace=False) + else: + dec = z + + # Optional in-patch prediction + # if opt: change CPC task to smaller scale prediction (within patch -> smaller receptive field) + # by extra unfolding + "cropping" (to avoid overweighing lower layers and memory overflow) + if self.inpatch_prediction and self.encoder_num < self.inpatch_prediction_limit: + extra_patch_size = [2 for _ in range(self.inpatch_prediction_limit)] + extra_patch_steps = [1 for _ in range(self.inpatch_prediction_limit)] + + dec = dec.reshape(-1, n_patches_x, n_patches_y, dec.shape[1], dec.shape[2], dec.shape[3]) # b, n_patches_y, n_patches_x, c, y, x + # random "cropping"/selecting of patches that will be extra unfolded + extra_crop_size = [n_patches_x // 2 for _ in range(self.inpatch_prediction_limit)] + inds = np.random.randint(0, n_patches_x - extra_crop_size[self.encoder_num], 2) + dec = dec[:, inds[0]:inds[0]+extra_crop_size[self.encoder_num], inds[1]:inds[1]+extra_crop_size[self.encoder_num],:,:,:] + + # extra unfolding + dec = ( + dec.unfold(4, extra_patch_size[self.encoder_num], extra_patch_steps[self.encoder_num]) + .unfold(5, extra_patch_size[self.encoder_num], extra_patch_steps[self.encoder_num]) # b, n_patches_y, n_patches_x, c, n_extra_patches, n_extra_patches, extra_patch_size, extra_patch_size + .permute(0, 1, 2, 4, 5, 3, 6, 7) # b, n_patches_y(_reduced), n_patches_x(_reduced), n_extra_patches, n_extra_patches, c, extra_patch_size, extra_patch_size + ) + n_extra_patches = dec.shape[3] + dec = dec.reshape(dec.shape[0] * dec.shape[1] * dec.shape[2] * dec.shape[3] * dec.shape[4], dec.shape[5], dec.shape[6], dec.shape[7]) + # b * n_patches_y(_reduced) * n_patches_x(_reduced) * n_extra_patches * n_extra_patches, c, extra_patch_size, extra_patch_size + + # Pool over patch + # in original CPC/GIM, pooling is done over whole patch, i.e. 
output shape 1 by 1 + out = F.adaptive_avg_pool2d(dec, self.patch_average_pool_out_dim) # b * n_patches_y(_reduced) * n_patches_x(_reduced) (* n_extra_patches * n_extra_patches), c, x_pooled, y_pooled + # Flatten over channel and pooled patch dimensions x_pooled, y_pooled: + out = out.reshape(out.shape[0], -1) # b * n_patches_y(_reduced) * n_patches_x(_reduced) (* n_extra_patches * n_extra_patches), c * y_pooled * x_pooled + + if self.inpatch_prediction and self.encoder_num < self.inpatch_prediction_limit: + n_p_x, n_p_y = n_extra_patches, n_extra_patches + else: + n_p_x, n_p_y = n_patches_x, n_patches_y + + out = out.reshape(-1, n_p_y, n_p_x, out.shape[1]) # b, n_patches_y, n_patches_x, c * y_pooled * x_pooled OR b * n_patches_y(_reduced) * n_patches_x(_reduced), n_extra_patches, n_extra_patches, c * y_pooled * x_pooled + out = out.permute(0, 3, 1, 2).contiguous() # b, c * y_pooled * x_pooled, n_patches_y, n_patches_x OR b * n_patches_y(_reduced) * n_patches_x(_reduced), c * y_pooled * x_pooled, n_extra_patches, n_extra_patches + + return out, z, n_patches_y, n_patches_x + + # crop feature map such that the loss always predicts/averages over same amount of patches (as the last one) + def random_spatial_crop(self, out, n_patches_x, n_patches_y): + n_patches_x_crop = n_patches_x // (self.max_patch_size // self.patch_size_eff) + n_patches_y_crop = n_patches_y // (self.max_patch_size // self.patch_size_eff) + if n_patches_x == n_patches_x_crop: + posx = 0 + else: + posx = np.random.randint(0, n_patches_x - n_patches_x_crop + 1) + if n_patches_y == n_patches_y_crop: + posy = 0 + else: + posy = np.random.randint(0, n_patches_y - n_patches_y_crop + 1) + out = out[:, :, posy:posy+n_patches_y_crop, posx:posx+n_patches_x_crop] + return out + + def evaluate_loss(self, outs, cur_idx, label, gating=None): + accuracy = torch.zeros(1) + gating_out = None + if self.calc_loss and self.opt.loss == 0: + # Special cases of predicting module below or same module and below ('both') + if self.predict_module_num=='-1' or self.predict_module_num=='both': # gating not implemented here! 
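+            # (editor's note, not part of the original patch) argument convention of self.loss is loss(z, c):
+            # the first input is the encoding z that is predicted, the second the context c that predicts it
+            # (cf. in_channels / out_channels of InfoNCE_Loss). For predict_module_num '-1' / 'both', z is
+            # therefore taken from the module below (outs[cur_idx-1]) and c from the current module (outs[cur_idx]).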
+ if self.asymmetric_W_pred: + raise NotImplementedError("asymmetric W not implemented yet for predicting lower layers!") + if self.encoder_num==0: # normal loss for first module + loss, loss_gated, gating_out = self.loss(outs[cur_idx], outs[cur_idx], gating=gating) # z, c + else: + loss, loss_gated, _ = self.loss(outs[cur_idx-1], outs[cur_idx]) # z, c + if self.predict_module_num=='both': + loss_intralayer, _, _ = self.loss_same_module(outs[cur_idx], outs[cur_idx]) + loss = 0.5 * (loss + loss_intralayer) + + elif self.predict_module_num=='-1b': + if self.asymmetric_W_pred: + raise NotImplementedError("asymmetric W not implemented yet for predicting lower layers!") + if self.encoder_num == len(outs)-1: # normal loss for last module + loss, loss_gated, gating_out = self.loss(outs[cur_idx], outs[cur_idx], gating=gating) # z, c + else: + loss, loss_gated, _ = self.loss(outs[cur_idx], outs[cur_idx+1]) # z, c + # Normal case for prediction within same layer + else: + if self.asymmetric_W_pred: # u = z*W_pred*c -> u = drop_grad(z)*W_pred1*c + z*W_pred2*drop_grad(c) + if self.opt.contrast_mode != 'hinge': + raise ValueError("asymmetric_W_pred only implemented for hinge contrasting!") + + loss, loss_gated, _ = self.loss(outs[cur_idx], outs[cur_idx].clone().detach(), gating=gating) # z, detach(c) + + + loss_mirror, loss_mirror_gated, _ = self.loss_mirror(outs[cur_idx].clone().detach(), outs[cur_idx], gating=gating) # detach(z), c + + loss = loss + loss_mirror + loss_gated = loss_gated + loss_mirror_gated + else: + loss, loss_gated, gating_out = self.loss(outs[cur_idx], outs[cur_idx], gating=gating) # z, c + + elif self.calc_loss and self.opt.loss == 1: # supervised loss + loss, accuracy = self.loss(outs[cur_idx], label) + loss_gated, gating_out = -1, -1 + else: # only forward pass for downstream classification + loss, loss_gated, accuracy, gating_out = None, None, None, None + + return loss, loss_gated, accuracy, gating_out diff --git a/vision/GreedyInfoMax/vision/models/__init__.py b/vision/GreedyInfoMax/vision/models/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/vision/GreedyInfoMax/vision/models/load_vision_model.py b/vision/GreedyInfoMax/vision/models/load_vision_model.py new file mode 100755 index 0000000..c59499f --- /dev/null +++ b/vision/GreedyInfoMax/vision/models/load_vision_model.py @@ -0,0 +1,49 @@ +import torch + +from GreedyInfoMax.vision.models import FullModel, ClassificationModel +from GreedyInfoMax.utils import model_utils + + +def load_model_and_optimizer(opt, num_GPU=None, reload_model=False, calc_loss=True): + + model = FullModel.FullVisionModel( + opt, calc_loss + ) + + optimizer = [] + if opt.model_splits == 1: + optimizer.append(torch.optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)) + elif opt.model_splits >= 2: + # use separate optimizer for each module, so gradients don't get mixed up + for idx, layer in enumerate(model.encoder): + optimizer.append(torch.optim.Adam(layer.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)) + else: + raise NotImplementedError + # Note: module.parameters() acts recursively by default and adds all parameters of submodules as well + + model, num_GPU = model_utils.distribute_over_GPUs(opt, model, num_GPU=num_GPU) + + model, optimizer = model_utils.reload_weights( + opt, model, optimizer, reload_model=reload_model + ) + + return model, optimizer + +def load_classification_model(opt): + if opt.in_channels == None: + in_channels = 1024 + else: + in_channels = opt.in_channels 
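+    # (editor's note, not part of the original patch) in_channels must match the channel dimension of the
+    # encoder module whose output feeds the linear classifier; 1024 corresponds to the top module of the
+    # default VGG-like encoder. For intermediate modules, scripts/class_from_inter_layers.sh passes
+    # --in_channels from (128 256 256 512 1024 1024).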
+ + if opt.dataset == "stl10" or opt.dataset == "cifar10": + num_classes = 10 + elif opt.dataset == "cifar100": + num_classes = 100 + else: + raise Exception("Invalid option") + + classification_model = ClassificationModel.ClassificationModel( + in_channels=in_channels, num_classes=num_classes, + ).to(opt.device) + + return classification_model diff --git a/vision/GreedyInfoMax/vision/visualise.py b/vision/GreedyInfoMax/vision/visualise.py new file mode 100755 index 0000000..b799692 --- /dev/null +++ b/vision/GreedyInfoMax/vision/visualise.py @@ -0,0 +1,294 @@ + +# Call as downstream classification, e.g. +# python -m GreedyInfoMax.vision.visualise --model_path ./logs/your_simulation --model_num 299 --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --module_num 6 --batch_size 100 + +################################################################################ + +import torch +import numpy as np +from numpy.random import choice +import time +import os +import code +import sklearn +from IPython import embed +import matplotlib.pyplot as plt +import torchvision.transforms as transforms + +## own modules +from GreedyInfoMax.vision.data import get_dataloader +from GreedyInfoMax.vision.arg_parser import arg_parser +from GreedyInfoMax.vision.models import load_vision_model +from GreedyInfoMax.utils import logger, utils + + +def load_model_and_data(opt): + add_path_var = "linear_model" + + arg_parser.create_log_path(opt, add_path_var=add_path_var) + opt.training_dataset = "train" + + # load pretrained model + # cannot switch opt.reduced_patch_pooling = False here because otherwise W_preds sizes don't match + context_model, _ = load_vision_model.load_model_and_optimizer( + opt, reload_model=True, calc_loss=False + ) + context_model.module.switch_calc_loss(False) + + ## model_type=2 is supervised model which trains entire architecture; otherwise just extract features + if opt.model_type != 2: + context_model.eval() + + if opt.module_num==-1: + print("CAREFUL! Training classifier directly on input image! 
Model is ignored and returns the (flattened) input images!") + + _, _, train_loader, _, test_loader, _ = get_dataloader.get_dataloader(opt) + + return context_model, train_loader, test_loader + +def get_representations(opt, model, data_loader, module=None, reload=True): + if type(data_loader.dataset.transform.transforms[0].transforms[0]) != transforms.transforms.CenterCrop: + raise Exception("Data loader should use deterministic cropping (Center Crop) for image patch visualisation!") + + if module == None: + module = opt.module_num + + if reload: + print("Reload representations...") + (inputs, reps, targets) = torch.load(os.path.join(opt.model_path, 'saved_reps_module'+str(module)), map_location=torch.device('cpu')) + else: + print("Calculate representations...") + inputs = [] + reps = [] + targets = [] + for step, (img, target) in enumerate(data_loader): + print("batch number: ", step, " of ", len(data_loader)) + model_input = img.to(opt.device) + if opt.model_type == 2: ## fully supervised training + _, _, z = model(model_input) + else: + with torch.no_grad(): + _, _, _, z, _ = model(model_input, target, n=module) + + inputs.append(model_input) + reps.append(z.detach()) + targets.append(target) + + inputs = torch.cat(inputs).cpu() + reps = torch.cat(reps).cpu() + targets = torch.cat(targets).cpu() + + torch.save((inputs, reps, targets), os.path.join(opt.model_path, 'saved_reps_module'+str(module))) + + return inputs, reps, targets + +############################################################################################################## +# Visualisation of learned "Manifold" by t-SNE embedding + +def tSNE(opt, inputs, reps, targets, class_names, n_points = None): + print("Doing t-SNE...") + + n_samples = targets.shape[0] + if n_points == None: + n_points = n_samples + + d_inputs = inputs.reshape(n_samples, -1) + + reps_m = torch.mean(reps,(2,3)) # spatial mean pooling + d_reps = reps_m.reshape(n_samples, -1) + + tsne_inputs = sklearn.manifold.TSNE(perplexity = 50) + tsne_reps = sklearn.manifold.TSNE(perplexity = 50) + + #t_inputs = tsne_inputs.fit_transform(d_inputs[:n_points,:]) + t_reps = tsne_reps.fit_transform(d_reps[:n_points,:]) + + #tSNE_plot(opt, t_inputs, targets[:n_points], class_names, fig_name_ext = 'input') + tSNE_plot(opt, t_reps, targets[:n_points], class_names) + +def tSNE_plot(opt, t_data, targets, class_names, fig_name_ext = '', markersize = 2, plot_legend = False): + plt.figure() + for class_index in range(10): # loop over classes + inds = targets == class_index + if sum(inds) > 0: # exclude empty sets + t_data_plot = t_data[inds, :] + plt.scatter(t_data_plot[:,0],t_data_plot[:,1], label = class_names[class_index], s=markersize) + plt.axis('off') + if plot_legend: + plt.legend(markerscale=3) + plt.savefig(os.path.join(opt.model_path, 'tSNE_module'+str(opt.module_num)+fig_name_ext+'.pdf')) + +############################################################################################################## +# Visualisation of learned "Manifold" by looking at neighbour encodings in feature space + +def unravel_index(indices: torch.LongTensor, shape) -> torch.LongTensor: + """Converts flat indices into unraveled coordinates in a target shape. + This is a `torch` implementation of `numpy.unravel_index`. + Args: + indices: A tensor of (flat) indices, (*, N). + shape: The targeted shape, (D,). + + Returns: + The unraveled coordinates, (*, N, D). 
+ """ + + coord = [] + + for dim in reversed(shape): + coord.append(indices % dim) + indices = indices // dim + + coord = torch.stack(coord[::-1], dim=-1) + + return coord + +def patch_neighbouranalysis(opt, imgs, reps, + n_examples = 5, n_neighbours = 10, patch_size = 16, patch_spacing = 8, center_crop_margin = 16, do_savefig = False): + + n_patches = reps.shape[-1] + patch_coords = choice([i for i in range(n_patches)], size = (n_examples, 2), replace=True) + + ts, neighbours_list = find_neighbours(reps, n_examples, n_neighbours, patch_coords) + + plot_patches(opt, imgs, patch_coords, ts, neighbours_list, n_examples, n_neighbours, + patch_size = patch_size, patch_spacing = patch_spacing, center_crop_margin = center_crop_margin, do_savefig = do_savefig) + + +def find_neighbours(reps, n_examples, n_neighbours, patch_coords): + # reps: b, c, x, y + s = reps.shape + + ts = [] + neighbours_list = [] + for i in range(n_examples): + t = np.random.randint(s[0]) + ts.append(t) + rep_t = reps[t, :, patch_coords[i,0], patch_coords[i,1]].squeeze() # c (reference patch) + + rep = reps.permute(1, 0, 2, 3).reshape(s[1], -1).permute(1,0) # b*x*y, c (flattened patch list) + + l2_dif = torch.norm(rep - rep_t, dim=1) # b*x*y + neighbours_flattened = l2_dif.argsort() + neighbours_array = unravel_index(neighbours_flattened, (s[0], s[2], s[3])) # b*x*y, 3 + neighbours = neighbours_array[1:n_neighbours+1] # n_neighbours, 3 (time, x, y) (closest neighbours, first excluded since same patch) + + neighbours_list.append(neighbours) + + return ts, neighbours_list + + +def plot_patches(opt, imgs, patch_coords, ts, neighbours_list, n_examples, n_neighbours, + patch_size = 16, patch_spacing = 8, center_crop_margin = 8, plot_reference_patch = True, crop = False, do_savefig = False): + # neighbours_list is list of 2dim arrays, each of which with dimensions: n_neighbours, 3 (time, x, y) + + def _add_patch_frame(img, p_coord, p_size, p_space, margin, color = 'black', extra_line_width = 2): + def _set_pixel_values(img, c, value, p_coord, p_size, p_space, margin, extra_line_width): + x0 = margin+p_coord[0]*p_space + y0 = margin+p_coord[1]*p_space + img[x0:x0+p_size+extra_line_width, y0:y0+extra_line_width, c] = value + img[x0:x0+p_size+extra_line_width, y0+p_size:y0+p_size+extra_line_width, c] = value + img[x0:x0+extra_line_width, y0:y0+p_size+extra_line_width, c] = value + img[x0+p_size:x0+p_size+extra_line_width, y0:y0+p_size+extra_line_width, c] = value + return img + + img = _set_pixel_values(img, [0,1,2], 0, p_coord, p_size, p_space, margin, extra_line_width) # black frame + if color == 'red': + _set_pixel_values(img, [0], 255, p_coord, p_size, p_space, margin, extra_line_width) # red frame + + return img + + n_plots = n_neighbours + if plot_reference_patch: # first one is reference patch itself + n_plots += 1 + + imgs_list = [] + for ex in range(n_examples): + imgs_select = [] + + if plot_reference_patch: + img = imgs[ts[ex], :, :, :].copy().transpose(1,2,0) # full, uncropped color image x, y, c + img = _add_patch_frame(img, patch_coords[ex, :], patch_size, patch_spacing, center_crop_margin, color = 'black') + imgs_select.append(img) + for n in range(n_neighbours): + img = imgs[neighbours_list[ex][n][0], :, :, :].copy().transpose(1,2,0) # full, uncropped color image x, y, c + img = _add_patch_frame(img, neighbours_list[ex][n][1:], patch_size, patch_spacing, center_crop_margin, color = 'red') + imgs_select.append(img) + + imgs_list.append(imgs_select) + + fig, axes = plt.subplots(nrows=n_examples, ncols=n_plots, 
sharex=True, sharey=True) # , gridspec_kw={'wspace': 0.05}) + fig.set_size_inches(10, 10) + for i in range(n_examples): + for j in range(n_plots): + ax = axes[i][j] + if crop: + ax.imshow(imgs_list[i][j][center_crop_margin:-center_crop_margin, center_crop_margin:-center_crop_margin]) + else: + ax.imshow(imgs_list[i][j]) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect('equal') + + if do_savefig: + plt.savefig(os.path.join(opt.model_path, 'patch_visualisation_module'+str(opt.module_num)+'.pdf')) + + +############################################################################################################## +# Visualisation of neuron receptive fields by plotting maximally activating patches + +def max_activating_patches(opt, imgs, reps, + n_neurons = 5, n_max = 8, patch_size = 16, patch_spacing = 8, center_crop_margin = 16, do_savefig = False): + # reps: b, c, x, y + s = reps.shape + + neuron_inds = choice([i for i in range(s[1])], size = (n_neurons), replace=False) + responses = reps[:, neuron_inds, :, :].permute(1, 0, 2, 3).reshape(n_neurons, -1) # n_neurons, b*x*y + + patches_list = [] + for n in range(n_neurons): + inds_flattened = responses[n].argsort(descending=True) # b*x*y (sorted by value, largest first) + inds = unravel_index(inds_flattened, (s[0], s[2], s[3])) # b*x*y, 3 (time, x, y) + patches = inds[:n_max].clone() # n_max, 3 (max. activating patches) + + ctr = 0 + tns = [] + tns.append(patches[0][0]) + for np in range(1,n_max): + tnp = patches[np][0] + while tnp in tns: + ctr += 1 + patches[np][:] = inds[n_max + ctr] + tnp = patches[np][0] + tns.append(tnp) + + + patches_list.append(patches) + + plot_patches(opt, imgs, None, None, patches_list, n_neurons, n_max, + patch_size = patch_size, patch_spacing = patch_spacing, center_crop_margin = center_crop_margin, + plot_reference_patch = False, crop = True, do_savefig = do_savefig) + + # embed() + +############################################################################################################## + +if __name__ == "__main__": + + opt = arg_parser.parse_args() + + model, _, test_loader = load_model_and_data(opt) + + imgs = test_loader.dataset.data + class_names = test_loader.dataset.classes + + inputs, reps, targets = get_representations(opt, model, test_loader, reload = True) + + tSNE(opt, inputs, reps, targets, class_names) # , n_points = 1000) + + # patch_neighbouranalysis(opt, imgs, reps, n_examples = 5, n_neighbours = 8) + + max_activating_patches(opt, imgs, reps, n_neurons = 8, n_max = 10, do_savefig = True) + + embed() + + \ No newline at end of file diff --git a/vision/environment.yml b/vision/environment.yml new file mode 100644 index 0000000..a1d0188 --- /dev/null +++ b/vision/environment.yml @@ -0,0 +1,101 @@ +name: infomax +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - blas=1.0 + - ca-certificates=2018.12.5 + - certifi=2018.11.29 + - cffi=1.11.5 + - cycler=0.10.0 + - dbus=1.13.2 + - expat=2.2.5 + - fontconfig=2.13.1 + - freetype=2.9.1 + - gettext=0.19.8.1 + - glib=2.56.2 + - gst-plugins-base=1.14.0 + - gstreamer=1.14.0 + - icu=58.2 + - intel-openmp=2019.1 + - jpeg=9b + - kiwisolver=1.0.1 + - libedit=3.1.20181209 + - libffi=3.2.1 + - libiconv=1.15 + - libpng=1.6.36 + - libtiff=4.0.10 + - libuuid=2.32.1 + - libxcb=1.13 + - libxml2=2.9.8 + - sox=14.4.2 + - lz4=2.1.6 + - lz4-c=1.8.1.2 + - matplotlib=3.0.2 + - matplotlib-base=3.0.2 + - mkl=2019.1 + - mkl_fft=1.0.10 + - mkl_random=1.0.2 + - ncurses=6.1 + - ninja=1.8.2 + - numpy=1.15.4 + - numpy-base=1.15.4 + - olefile=0.46 + - 
openssl=1.1.1a + - pandas=0.24.0 + - patsy=0.5.1 + - pcre=8.42 + - pillow=5.4.1 + - pip=18.1 + - pthread-stubs=0.4 + - pycparser=2.19 + - pyparsing=2.3.1 + - pyqt=5.6.0 + - python=3.6.8 + - python-dateutil=2.7.5 + - pytorch=1.0.0 + - pytz=2018.9 + - qt=5.6.3 + - readline=7.0 + - scikit-learn=0.20.2 + - scipy=1.2.0 + - seaborn=0.9.0 + - setuptools=40.6.3 + - sip=4.18.1 + - six=1.12.0 + - sqlite=3.26.0 + - statsmodels=0.9.0 + - torchvision=0.2.1 + - tornado=5.1.1 + - tqdm=4.29.1 + - wget=1.19.5 + - wheel=0.32.3 + - xorg-libxau=1.0.8 + - xorg-libxdmcp=1.1.2 + - xz=5.2.4 + - zlib=1.2.11 + - zstd=1.3.7 + - pip: + + - gdown==3.8.3 + - appdirs==1.4.3 + - attrs==19.1.0 + - black==19.3b0 + - click==7.0 + - decorator==4.4.0 + - docopt==0.6.2 + - imageio==2.5.0 + - jsonpickle==0.9.6 + - munch==2.3.2 + - networkx==2.3 + - protobuf==3.7.1 + - py-cpuinfo==5.0.0 + - pywavelets==1.0.3 + - sacred==0.7.4 + - scikit-image==0.15.0 + - tensorboardx==1.6 + - toml==0.10.0 + - torchaudio -f https://download.pytorch.org/whl/torch_stable.html + - torchviz==0.0.1 + - wrapt==1.11.1 diff --git a/vision/scripts/class_from_inter_layers.sh b/vision/scripts/class_from_inter_layers.sh new file mode 100755 index 0000000..a5b405b --- /dev/null +++ b/vision/scripts/class_from_inter_layers.sh @@ -0,0 +1,15 @@ +# Script for evaluating intermediate layers through linear classification from hidden layers. + +#!/bin/sh + +savepath=./logs/YOURMODEL +n_epochs=299 # 599 +n_in_channels=(128 256 256 512 1024 1024) + +for i in 1 2 3 4 5 6 +do + echo "Testing the model for linear image classification from layer/module $i" + python -m GreedyInfoMax.vision.downstream_classification --model_path $savepath --model_num $n_epochs --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --module_num $i --in_channels ${n_in_channels[$i-1]} + mv $savepath/classification_results.txt $savepath/classification_results_layer_$i.txt + mv $savepath/classification_results_values.npy $savepath/classification_results_values_$i.npy +done \ No newline at end of file diff --git a/vision/scripts/vision_traineval_CLAPP.sh b/vision/scripts/vision_traineval_CLAPP.sh new file mode 100755 index 0000000..62f35b4 --- /dev/null +++ b/vision/scripts/vision_traineval_CLAPP.sh @@ -0,0 +1,11 @@ +# Commands to +# (i) train a model with CLAPP (use 600 epochs because of asynchronous positive and negative updates) +# (ii) evaluate the trained model with linear downstream classification on last layer + +#!/bin/sh + +echo "Training the model on vision data (stl-10)" +python -m GreedyInfoMax.vision.main_vision --download_dataset --save_dir CLAPP --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --contrast_mode 'hinge' --asymmetric_W_pred --num_epochs 600 --negative_samples 1 --sample_negs_locally --sample_negs_locally_same_everywhere --either_pos_or_neg_update + +echo "Testing the model for image classification" +python -m GreedyInfoMax.vision.downstream_classification --model_path ./logs/CLAPP --model_num 599 --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --module_num 6 --asymmetric_W_pred diff --git a/vision/scripts/vision_traineval_CLAPP_s_sym_W_pred.sh b/vision/scripts/vision_traineval_CLAPP_s_sym_W_pred.sh new file mode 100755 index 0000000..9cadefc --- /dev/null +++ b/vision/scripts/vision_traineval_CLAPP_s_sym_W_pred.sh @@ -0,0 +1,11 @@ +# Commands to +# (i) train a model with CLAPP-s, i.e. synchronous positive and negative updates. 
Negatives are sampled from all locations and transposed W_pred is used (not fully local compared to CLAPP) +# (ii) evaluate the trained model with linear downstream classification on last layer + +#!/bin/sh + +echo "Training the model on vision data (stl-10)" +python -m GreedyInfoMax.vision.main_vision --download_dataset --save_dir CLAPP_s --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --contrast_mode 'hinge' + +echo "Testing the model for image classification" +python -m GreedyInfoMax.vision.downstream_classification --model_path ./logs/CLAPP_s --model_num 299 --encoder_type 'vgg_like' --model_splits 6 --train_module 6 --module_num 6 diff --git a/vision/scripts/vision_traineval_HingeLossCPC.sh b/vision/scripts/vision_traineval_HingeLossCPC.sh new file mode 100755 index 0000000..906566f --- /dev/null +++ b/vision/scripts/vision_traineval_HingeLossCPC.sh @@ -0,0 +1,12 @@ +# Commands to +# (i) train a model with Hinge Loss CPC, the end-to-end version of CLAPP (use 600 epochs because of asynchronous positive and negative updates) +# (ii) evaluate the trained model with linear downstream classification on last layer + +#!/bin/sh + +echo "Training the model on vision data (stl-10)" +python -m GreedyInfoMax.vision.main_vision --download_dataset --save_dir HingeLossCPC --num_epochs 600 --encoder_type 'vgg_like' --model_splits 1 --train_module 1 --contrast_mode 'hinge' --negative_samples 1 --sample_negs_locally --sample_negs_locally_same_everywhere --either_pos_or_neg_update --asymmetric_W_pred + +echo "Testing the model for image classification" +python -m GreedyInfoMax.vision.downstream_classification --model_path ./logs/HingeLossCPC --model_num 599 --encoder_type 'vgg_like' --model_splits 1 --train_module 1 --module_num 1 --asymmetric_W_pred + diff --git a/vision/setup_dependencies.sh b/vision/setup_dependencies.sh new file mode 100644 index 0000000..6ba94fd --- /dev/null +++ b/vision/setup_dependencies.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +echo "Make sure conda is installed." +echo "Installing environment:" +conda env create -f environment.yml || conda env update -f environment.yml || exit +conda activate infomax
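+# (editor's note, illustrative; not part of the original patch) once the environment is active, the
+# training/evaluation pipelines can presumably be launched from the vision/ directory via the provided
+# scripts, e.g.:
+#   bash scripts/vision_traineval_CLAPP.sh        # CLAPP pre-training + linear readout on the last layer
+#   bash scripts/class_from_inter_layers.sh       # linear probes on intermediate modules (set savepath first)
+# Adjust the paths if the scripts are called from a different working directory.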