diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..0898bc1 Binary files /dev/null and b/.DS_Store differ diff --git a/dataset_loader.py b/dataset_loader.py new file mode 100644 index 0000000..14788df --- /dev/null +++ b/dataset_loader.py @@ -0,0 +1,149 @@ +import os + +import numpy as np +import PIL.Image +import scipy.io as sio +import torch +from torch.utils import data + +class MyData(data.Dataset): # inherit + """ + load data in a folder + """ + mean_rgb = np.array([0.447, 0.407, 0.386]) + std_rgb = np.array([0.244, 0.250, 0.253]) + def __init__(self, root, transform=False): + super(MyData, self).__init__() + self.root = root + + self._transform = transform + + img_root = os.path.join(self.root, 'train_images') + lbl_root = os.path.join(self.root, 'train_masks') + depth_root = os.path.join(self.root, 'train_depth') + + file_names = os.listdir(img_root) + self.img_names = [] + self.lbl_names = [] + self.depth_names = [] + for i, name in enumerate(file_names): + if not name.endswith('.jpg'): + continue + self.lbl_names.append( + os.path.join(lbl_root, name[:-4]+'.png') + ) + self.img_names.append( + os.path.join(img_root, name) + ) + self.depth_names.append( + os.path.join(depth_root, name[:-4]+'.png') + ) + + def __len__(self): + return len(self.img_names) + + def __getitem__(self, index): + # load image + img_file = self.img_names[index] + img = PIL.Image.open(img_file) + # img = img.resize((256, 256)) + img = np.array(img, dtype=np.uint8) + # load label + lbl_file = self.lbl_names[index] + lbl = PIL.Image.open(lbl_file) + # lbl = lbl.resize((256, 256)) + lbl = np.array(lbl, dtype=np.int32) + lbl[lbl != 0] = 1 + # load depth + depth_file = self.depth_names[index] + depth = PIL.Image.open(depth_file) + # depth = depth.resize(256, 256) + depth = np.array(depth, dtype=np.uint8) + + + + if self._transform: + return self.transform(img, lbl, depth) + else: + return img, lbl, depth + + + # Translating numpy_array into format that pytorch can use on Code. 
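+    # transform() rescales the RGB image to [0, 1], normalizes it with the mean_rgb/std_rgb
+    # defined above, reorders HWC -> CHW, and returns float image, long label and float
+    # depth tensors ready for training.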
+ def transform(self, img, lbl, depth): + + img = img.astype(np.float64)/255.0 + img -= self.mean_rgb + img /= self.std_rgb + img = img.transpose(2, 0, 1) # to verify + img = torch.from_numpy(img).float() + lbl = torch.from_numpy(lbl).long() + depth = depth.astype(np.float64)/255.0 + depth = torch.from_numpy(depth).float() + return img, lbl, depth + + +class MyTestData(data.Dataset): + """ + load data in a folder + """ + mean_rgb = np.array([0.447, 0.407, 0.386]) + std_rgb = np.array([0.244, 0.250, 0.253]) + + + def __init__(self, root, transform=False): + super(MyTestData, self).__init__() + self.root = root + self._transform = transform + + img_root = os.path.join(self.root, 'test_images') + depth_root = os.path.join(self.root, 'test_depth') + file_names = os.listdir(img_root) + self.img_names = [] + self.names = [] + self.depth_names = [] + + for i, name in enumerate(file_names): + if not name.endswith('.jpg'): + continue + self.img_names.append( + os.path.join(img_root, name) + ) + self.names.append(name[:-4]) + self.depth_names.append( + # os.path.join(depth_root, name[:-4]+'_depth.png') # Test RGBD135 dataset + os.path.join(depth_root, name[:-4] + '.png') + ) + + def __len__(self): + return len(self.img_names) + + def __getitem__(self, index): + # load image + img_file = self.img_names[index] + img = PIL.Image.open(img_file) + img_size = img.size + # img = img.resize((256, 256)) + img = np.array(img, dtype=np.uint8) + + # load focal + depth_file = self.depth_names[index] + depth = PIL.Image.open(depth_file) + # depth = depth.resize(256, 256) + depth = np.array(depth, dtype=np.uint8) + if self._transform: + img, focal = self.transform(img, depth) + return img, focal, self.names[index], img_size + else: + return img, depth, self.names[index], img_size + + def transform(self, img, depth): + img = img.astype(np.float64)/255.0 + img -= self.mean_rgb + img /= self.std_rgb + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).float() + + depth = depth.astype(np.float64)/255.0 + depth = torch.from_numpy(depth).float() + + return img, depth diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..f92c341 --- /dev/null +++ b/demo.py @@ -0,0 +1 @@ +""" Title: Depth-induced Multi-scale Recurrent Attention Network for Saliency Detection Author: Wei Ji, Jingjing Li E-mail: weiji.dlut@gmail.com """ import torch from torch.autograd import Variable from torch.utils.data import DataLoader import torchvision import torch.nn.functional as F import torch.optim as optim from dataset_loader import MyData, MyTestData from model import RGBNet,DepthNet from fusion import ConvLSTM from functions import imsave import argparse from trainer import Trainer import os configurations = { # same configuration as original work # https://github.com/shelhamer/fcn.berkeleyvision.org 1: dict( max_iteration=1000000, lr=1.0e-10, momentum=0.99, weight_decay=0.0005, spshot=20000, nclass=2, sshow=10, ) } parser=argparse.ArgumentParser() parser.add_argument('--phase', type=str, default='test', help='train or test') parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters') # parser.add_argument('--train_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/train_data', help='path to train data') parser.add_argument('--train_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/train_data-augment', help='path to train data') parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/DUT-RGBD/test_data', 
help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/NJUD/test_data', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/NLPR/test_data', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/LFSD', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/SSD', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/STEREO', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/RGBD135', help='path to test data') # Need to set dataset_loader.py/line 113 parser.add_argument('--snapshot_root', type=str, default='./snapshot', help='path to snapshot') parser.add_argument('--salmap_root', type=str, default='./sal_map', help='path to saliency map') parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys()) args = parser.parse_args() cfg = configurations[args.config] cuda = torch.cuda.is_available() """""""""""~~~ dataset loader ~~~""""""""" train_dataRoot = args.train_dataroot test_dataRoot = args.test_dataroot if not os.path.exists(args.snapshot_root): os.mkdir(args.snapshot_root) if not os.path.exists(args.salmap_root): os.mkdir(args.salmap_root) if args.phase == 'train': SnapRoot = args.snapshot_root # checkpoint train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True), batch_size=2, shuffle=True, num_workers=4, pin_memory=True) else: MapRoot = args.salmap_root test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True), batch_size=1, shuffle=True, num_workers=4, pin_memory=True) print('data ready') """"""""""" ~~~nets~~~ """"""""" start_epoch = 0 start_iteration = 0 model_rgb = RGBNet(cfg['nclass']) model_depth = DepthNet(cfg['nclass']) model_clstm = ConvLSTM(input_channels=64, hidden_channels=[64, 32, 64], kernel_size=5, step=4, effective_step=[2, 4, 8]) if args.param is True: model_rgb.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'snapshot_iter_1000000.pth'))) model_depth.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'depth_snapshot_iter_1000000.pth'))) model_clstm.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'clstm_snapshot_iter_1000000.pth'))) else: vgg19_bn = torchvision.models.vgg19_bn(pretrained=True) model_rgb.copy_params_from_vgg19_bn(vgg19_bn) model_depth.copy_params_from_vgg19_bn(vgg19_bn) if cuda: model_rgb = model_rgb.cuda() model_depth = model_depth.cuda() model_clstm = model_clstm.cuda() if args.phase == 'train': # Trainer: class, defined in trainer.py optimizer_rgb = optim.SGD(model_rgb.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) optimizer_depth = optim.SGD(model_depth.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) optimizer_clstm = optim.SGD(model_clstm.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) training = Trainer( cuda=cuda, model_rgb=model_rgb, model_depth=model_depth, model_clstm=model_clstm, optimizer_rgb=optimizer_rgb, optimizer_depth=optimizer_depth, optimizer_clstm=optimizer_clstm, train_loader=train_loader, max_iter=cfg['max_iteration'], snapshot=cfg['spshot'],
outpath=args.snapshot_root, sshow=cfg['sshow'] ) training.epoch = start_epoch training.iteration = start_iteration training.train() else: for id, (data, depth, img_name, img_size) in enumerate(test_loader): print('testing batch %d' % (id+1)) inputs = Variable(data).cuda() inputs_depth = Variable(depth).cuda() n, c, h, w = inputs.size() depth = inputs_depth.view(n, h, w, 1).repeat(1, 1, 1, c) depth = depth.transpose(3, 1) depth = depth.transpose(3, 2) h1, h2, h3, h4, h5 = model_rgb(inputs) # RGBNet's output depth_vector, d1, d2, d3, d4, d5 = model_depth(depth) # DepthNet's output outputs_all = model_clstm(depth_vector, h1, h2, h3, h4, h5, d1, d2, d3, d4, d5) # Final output outputs_all = F.softmax(outputs_all, dim=1) outputs = outputs_all[0][1] outputs = outputs.cpu().data.resize_(h, w) imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size) print('The testing process has finished!') \ No newline at end of file diff --git a/figure/dataset.png b/figure/dataset.png new file mode 100644 index 0000000..ecdcf3c Binary files /dev/null and b/figure/dataset.png differ diff --git a/figure/overall.png b/figure/overall.png new file mode 100644 index 0000000..21928ca Binary files /dev/null and b/figure/overall.png differ diff --git a/functions.py b/functions.py new file mode 100644 index 0000000..3206c35 --- /dev/null +++ b/functions.py @@ -0,0 +1,24 @@ +import numpy as np +import matplotlib.pyplot as plt +import torch +from scipy.misc import imresize + +def imsave(file_name, img, img_size): + """ + save a torch tensor as an image + :param file_name: 'image/folder/image_name' + :param img: 3*h*w torch tensor + :return: nothing + """ + assert isinstance(img, torch.FloatTensor), \ + 'img must be a torch.FloatTensor' + ndim = len(img.size()) + assert ndim == 2 or ndim == 3, \ + 'img must be a 2 or 3 dimensional tensor' + + img = img.numpy() + img = imresize(img, [img_size[1][0], img_size[0][0]], interp='nearest') + if ndim == 3: + plt.imsave(file_name, np.transpose(img, (1, 2, 0))) + else: + plt.imsave(file_name, img, cmap='gray') diff --git a/fusion.py b/fusion.py new file mode 100644 index 0000000..6917bb8 --- /dev/null +++ b/fusion.py @@ -0,0 +1,336 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F + +''' +fusion: consists of DRB, DMSW, RAM.
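+(Abbreviations, matching the section comments below: DRB = Depth Refinement Block;
+DMSW = depth-induced multi-scale weighting, i.e. the ASPP-style branches weighted by
+the depth vector; RAM = recurrent attention module, i.e. the ConvLSTM with channel-
+and spatial-wise attention.)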
+''' + +class ConvLSTMCell(nn.Module): + def __init__(self, input_channels, hidden_channels, kernel_size, bias=True): + super(ConvLSTMCell, self).__init__() + + assert hidden_channels % 2 == 0 + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.bias = bias + self.kernel_size = kernel_size + self.num_features = 4 + + self.padding = (kernel_size - 1) //2 + self.conv = nn.Conv2d(self.input_channels + self.hidden_channels, 4 * self.hidden_channels, self.kernel_size, 1, + self.padding) + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal(m.weight.data, std=0.01) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, input, h, c): + + combined = torch.cat((input, h), dim=1) + A = self.conv(combined) + (ai, af, ao, ag) = torch.split(A, A.size()[1] // self.num_features, dim=1) + i = torch.sigmoid(ai) #input gate + f = torch.sigmoid(af) #forget gate + o = torch.sigmoid(ao) #output + g = torch.tanh(ag) #update_Cell + + new_c = f * c + i * g + new_h = o * torch.tanh(new_c) + return new_h, new_c, o + + @staticmethod + def init_hidden(batch_size, hidden_c, shape): + return (Variable(torch.zeros(batch_size, hidden_c, shape[0], shape[1])).cuda(), + Variable(torch.zeros(batch_size, hidden_c, shape[0], shape[1])).cuda()) + + +class ConvLSTM(nn.Module): + def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1], bias=True): + super(ConvLSTM, self).__init__() + self.input_channels = [input_channels] + hidden_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.num_layers = len(hidden_channels) + self.step = step + self.bias = bias + self.effective_step = effective_step + self._all_layers = [] + for i in range(self.num_layers): + name = 'cell{}'.format(i) + cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size, self.bias) + setattr(self, name, cell) + self._all_layers.append(cell) + + + + # --------------------------- Depth Refinement Block -------------------------- # + # DRB 1 + self.conv_refine1_1 = nn.Conv2d(64, 64, 3, padding=1) + self.bn_refine1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine1_1 = nn.PReLU() + self.conv_refine1_2 = nn.Conv2d(64, 64, 3, padding=1) + self.bn_refine1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine1_2 = nn.PReLU() + self.conv_refine1_3 = nn.Conv2d(64, 64, 3, padding=1) + self.bn_refine1_3 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine1_3 = nn.PReLU() + self.down_2_1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.down_2_2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + # DRB 2 + self.conv_refine2_1 = nn.Conv2d(128, 128, 3, padding=1) + self.bn_refine2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine2_1 = nn.PReLU() + self.conv_refine2_2 = nn.Conv2d(128, 128, 3, padding=1) + self.bn_refine2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine2_2 = nn.PReLU() + self.conv_refine2_3 = nn.Conv2d(128, 128, 3, padding=1) + self.bn_refine2_3 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine2_3 = nn.PReLU() + self.conv_r2_1 = nn.Conv2d(128, 64, 3, padding=1) + self.bn_r2_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu_r2_1 = nn.PReLU() + # DRB 3 + self.conv_refine3_1 = nn.Conv2d(256, 256, 3, padding=1) + self.bn_refine3_1 = 
nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine3_1 = nn.PReLU() + self.conv_refine3_2 = nn.Conv2d(256, 256, 3, padding=1) + self.bn_refine3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine3_2 = nn.PReLU() + self.conv_refine3_3 = nn.Conv2d(256, 256, 3, padding=1) + self.bn_refine3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine3_3 = nn.PReLU() + self.conv_r3_1 = nn.Conv2d(256, 64, 3, padding=1) + self.bn_r3_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu_r3_1 = nn.PReLU() + # DRB 4 + self.conv_refine4_1 = nn.Conv2d(512, 512, 3, padding=1) + self.bn_refine4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine4_1 = nn.PReLU() + self.conv_refine4_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn_refine4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine4_2 = nn.PReLU() + self.conv_refine4_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn_refine4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine4_3 = nn.PReLU() + self.conv_r4_1 = nn.Conv2d(512, 64, 3, padding=1) + self.bn_r4_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu_r4_1 = nn.PReLU() + # DRB 5 + self.conv_refine5_1 = nn.Conv2d(512, 512, 3, padding=1) + self.bn_refine5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine5_1 = nn.PReLU() + self.conv_refine5_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn_refine5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine5_2 = nn.PReLU() + self.conv_refine5_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn_refine5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu_refine5_3 = nn.PReLU() + self.conv_r5_1 = nn.Conv2d(512, 64, 3, padding=1) + self.bn_r5_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu_r5_1 = nn.PReLU() + + + # ----------------------------- Multi-scale ----------------------------- # + # Add new structure: ASPP Atrous spatial Pyramid Pooling based on DeepLab v3 + # part0: 1*1*64 Conv + self.conv5_conv_1 = nn.Conv2d(64, 64, 1, padding=0) # size: 64*64*64 + self.bn5_conv_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu5_conv_1 = nn.ReLU(inplace=True) + # part1: 3*3*64 Conv + self.conv5_conv = nn.Conv2d(64, 64, 3, padding=1) # size: 64*64*64 + self.bn5_conv = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu5_conv = nn.ReLU(inplace=True) + # part2: 3*3*64 (dilated=7) Atrous Conv + self.Atrous_conv_1 = nn.Conv2d(64, 64, 3, padding=7, dilation=7) # size: 64*64*64 + self.Atrous_bn5_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.Atrous_relu_1 = nn.ReLU(inplace=True) + # part3: 3*3*64 (dilated=5) Atrous Conv + self.Atrous_conv_2 = nn.Conv2d(64, 64, 3, padding=5, dilation=5) # size: 64*64*64 + self.Atrous_bn5_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.Atrous_relu_2 = nn.ReLU(inplace=True) + # part4: 3*3*64 (dilated=3) Atrous Conv + self.Atrous_conv_5 = nn.Conv2d(64, 64, 3, padding=3, dilation=3) # size: 64*64*64 + self.Atrous_bn5_5 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.Atrous_relu_5 = nn.ReLU(inplace=True) + # part5: Max_pooling # size: 16*16*64 + self.Atrous_pooling = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.Atrous_conv_pool = nn.Conv2d(64, 64, 1, padding=0) + self.Atrous_bn_pool = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + 
self.Atrous_relu_pool = nn.ReLU(inplace=True) + + + + # ----------------------------- Channel-wise Attention ----------------------------- # + self.conv_c = nn.Conv2d(64, 64, 3, padding=1) + self.conv_h = nn.Conv2d(64, 64, 3, padding=1) + self.pool_avg = nn.AvgPool2d(64, stride=2, ceil_mode=True) # 1/8 + + + + # ----------------------------- Sptatial-wise Attention ----------------------------- # + self.conv_s1 = nn.Conv2d(64 * self.num_layers, 64, 1, padding=0) + self.conv_s2 = nn.Conv2d(64 * self.num_layers, 1, 1, padding=0) + + + # ----------------------------- Prediction ----------------------------- # + self.conv_pred = nn.Conv2d(64, 2, 1, padding=0) + + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal(m.weight.data, std=0.01) + if m.bias is not None: + m.bias.data.zero_() + + + def forward(self,depth_vector,h1,h2,h3,h4,h5,d1,d2,d3,d4,d5): + internal_state = [] + + + # -------- apply DRB --------- # + # drb 1 + d1_1 = self.relu_refine1_1(self.bn_refine1_1(self.conv_refine1_1(d1))) + d1_2 = self.relu_refine1_2(self.bn_refine1_2(self.conv_refine1_2(d1_1))) + d1_2 = d1_2 + h1 # (256x256)*64 + d1_2 = self.down_2_2(self.down_2_1(d1_2)) + d1_2_0 = d1_2 + d1_3 = self.relu_refine1_3(self.bn_refine1_3(self.conv_refine1_3(d1_2))) + drb1 = d1_2_0 + d1_3 # (64 x 64)*64 + + # drb 2 + d2_1 = self.relu_refine2_1(self.bn_refine2_1(self.conv_refine2_1(d2))) + d2_2 = self.relu_refine2_2(self.bn_refine2_2(self.conv_refine2_2(d2_1))) + d2_2 = d2_2 + h2 # (128x128)*128 + d2_2 = self.down_2_1(d2_2) + d2_2_0 = d2_2 + d2_3 = self.relu_refine2_3(self.bn_refine2_3(self.conv_refine2_3(d2_2))) + drb2 = d2_2_0 + d2_3 + drb2 = self.relu_r2_1(self.bn_r2_1(self.conv_r2_1(drb2))) # (64 x 64)*64 + + # drb 3 + d3_1 = self.relu_refine3_1(self.bn_refine3_1(self.conv_refine3_1(d3))) + d3_2 = self.relu_refine3_2(self.bn_refine3_2(self.conv_refine3_2(d3_1))) + d3_2 = d3_2 + h3 # (64 x 64)*256 + d3_2_0 = d3_2 + d3_3 = self.relu_refine3_3(self.bn_refine3_3(self.conv_refine3_3(d3_2))) + drb3 = d3_2_0 + d3_3 + drb3 = self.relu_r3_1(self.bn_r3_1(self.conv_r3_1(drb3))) # (64 x 64)*64 + + # drb 4 + d4_1 = self.relu_refine4_1(self.bn_refine4_1(self.conv_refine4_1(d4))) + d4_2 = self.relu_refine4_2(self.bn_refine4_2(self.conv_refine4_2(d4_1))) + d4_2 = d4_2 + h4 # (32 x 32)*512 + d4_2 = F.upsample(d4_2, scale_factor=2, mode='bilinear') + d4_2_0 = d4_2 + d4_3 = self.relu_refine4_3(self.bn_refine4_3(self.conv_refine4_3(d4_2))) + drb4 = d4_2_0 + d4_3 + drb4 = self.relu_r4_1(self.bn_r4_1(self.conv_r4_1(drb4))) # (64 x 64)*64 + + # drb 5 + d5_1 = self.relu_refine5_1(self.bn_refine5_1(self.conv_refine5_1(d5))) + d5_2 = self.relu_refine5_2(self.bn_refine5_2(self.conv_refine5_2(d5_1))) + d5_2 = d5_2 + h5 # (16 x 16)*64 + d5_2 = F.upsample(d5_2, scale_factor=4, mode='bilinear') + d5_2_0 = d5_2 + d5_3 = self.relu_refine5_3(self.bn_refine5_3(self.conv_refine5_3(d5_2))) + drb5 = d5_2_0 + d5_3 + drb5 = self.relu_r5_1(self.bn_r5_1(self.conv_r5_1(drb5))) # (64 x 64)*64 + + drb_fusion = drb1 +drb2 + drb3 +drb4 +drb5 # (64 x 64)*64 + + + # --------------------- obtain multi-scale ----------------------- # + f1 = self.relu5_conv_1(self.bn5_conv_1(self.conv5_conv_1(drb_fusion))) + f2 = self.relu5_conv(self.bn5_conv(self.conv5_conv(drb_fusion))) + f3 = self.Atrous_relu_1(self.Atrous_bn5_1(self.Atrous_conv_1(drb_fusion))) + f4 = self.Atrous_relu_2(self.Atrous_bn5_2(self.Atrous_conv_2(drb_fusion))) + f5 = 
self.Atrous_relu_5(self.Atrous_bn5_5(self.Atrous_conv_5(drb_fusion))) + f6 = F.upsample( + self.Atrous_relu_pool(self.Atrous_bn_pool(self.Atrous_conv_pool(self.Atrous_pooling(self.Atrous_pooling(drb_fusion))))), + scale_factor=4, mode='bilinear') + + + + + fusion = torch.cat([f1,f2,f3,f4,f5,f6],dim=0) # 6x64x64x64 + fusion_o = fusion + input = torch.cat(torch.chunk(fusion, 6, dim=0), dim=1) + + + + + for step in range(self.step): + depth = depth_vector # 1x 6 x 1 x1 + + if step == 0: + basize, _, height, width = input.size() + (h_step, c) = ConvLSTMCell.init_hidden(basize, self.hidden_channels[self.num_layers-1],(height, width)) + + + # Feature-wise Attention + depth = torch.mul(F.softmax(depth,dim=1), 6) + + basize, dime, h, w = depth.size() + + depth = depth.view(1, basize, dime, h, w).transpose(0,1).transpose(1,2) + depth = torch.cat(torch.chunk(depth, basize, dim=0), dim=1).view(basize*dime, 1, 1, 1) + + depth = torch.mul(fusion_o, depth).view(1, basize*dime, 64, 64, 64) + depth = torch.cat(torch.chunk(depth, basize, dim=1), dim=0) + F_sum = torch.sum(depth, 1, keepdim=False)#.squeeze() + + + # Channel-wise Attention + depth_fw_ori = F_sum + depth = self.conv_c(F_sum) + h_c = self.conv_h(h_step) + depth = depth + h_c + depth = self.pool_avg(depth) + depth = torch.mul(F.softmax(depth, dim=1), 64) + F_sum_wt = torch.mul(depth_fw_ori, depth) + + + + x = F_sum_wt + if step < self.step-1: + for i in range(self.num_layers): + # all cells are initialized in the first step + if step == 0: + bsize, _, height, width = x.size() + (h, c) = ConvLSTMCell.init_hidden(bsize, self.hidden_channels[i], (height, width)) + internal_state.append((h, c)) + # do forward + name = 'cell{}'.format(i) + (h, c) = internal_state[i] + h_step = h + + x, new_c, new_o = getattr(self, name)(x, h, c) # ConvLSTMCell forward + internal_state[i] = (x, new_c) + + # only record effective steps + #if step in self.effective_step: + + if step == 0: + outputs_o = new_o + else: + outputs_o = torch.cat((outputs_o, new_o), dim=1) + + # ---------------> Spatial-wise Attention Module <----------------- # + outputs = self.conv_s1(outputs_o) + spatial_weight = F.sigmoid(self.conv_s2(outputs_o)) + outputs = torch.mul(outputs,spatial_weight) + # -------------------------> Prediction <-------------------------- # + outputs = self.conv_pred(outputs) + output = F.upsample(outputs, scale_factor=4, mode='bilinear') + + return output + diff --git a/model.py b/model.py new file mode 100644 index 0000000..d7e10db --- /dev/null +++ b/model.py @@ -0,0 +1,364 @@ +import torch +import torch.nn as nn +import numpy as np + + +def get_upsampling_weight(in_channels, out_channels, kernel_size): + """Make a 2D bilinear kernel suitable for upsampling""" + factor = (kernel_size + 1) // 2 + if kernel_size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + og = np.ogrid[:kernel_size, :kernel_size] + filt = (1 - abs(og[0] - center) / factor) * \ + (1 - abs(og[1] - center) / factor) + weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), + dtype=np.float64) + weight[range(in_channels), range(out_channels), :, :] = filt + return torch.from_numpy(weight).float() + + + + + +#################################### RGB Network ##################################### +class RGBNet(nn.Module): + def __init__(self,n_class=2): + super(RGBNet, self).__init__() + + # original image's size = 256*256*3 + + # conv1 + self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) + self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu1_1 
= nn.ReLU(inplace=True) + self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) + self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu1_2 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers + + # conv2 + self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) + self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) + self.relu2_1 = nn.ReLU(inplace=True) + self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) + self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) + self.relu2_2 = nn.ReLU(inplace=True) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers + + # conv3 + self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) + self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_1 = nn.ReLU(inplace=True) + self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_2 = nn.ReLU(inplace=True) + self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_3 = nn.ReLU(inplace=True) + self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_4 = nn.ReLU(inplace=True) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers + + # conv4 + self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) + self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_1 = nn.ReLU(inplace=True) + self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_2 = nn.ReLU(inplace=True) + self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_3 = nn.ReLU(inplace=True) + self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_4 = nn.ReLU(inplace=True) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers + + # conv5 + self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_1 = nn.ReLU(inplace=True) + self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_2 = nn.ReLU(inplace=True) + self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_3 = nn.ReLU(inplace=True) + self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_4 = nn.ReLU(inplace=True) # 1/32 4 layers + + self._initialize_weights() + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # m.weight.data.zero_() + nn.init.normal(m.weight.data, std=0.01) + if m.bias is not None: + m.bias.data.zero_() + if isinstance(m, nn.ConvTranspose2d): + assert m.kernel_size[0] == m.kernel_size[1] + initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) + m.weight.data.copy_(initial_weight) + + + + def forward(self, x): + h = x + + h = self.relu1_1(self.bn1_1(self.conv1_1(h))) + h = self.relu1_2(self.bn1_2(self.conv1_2(h))) + h_nopool1 = h + h = self.pool1(h) + h1 = h_nopool1 # (256x256)*64 + + h = self.relu2_1(self.bn2_1(self.conv2_1(h))) + h = 
self.relu2_2(self.bn2_2(self.conv2_2(h))) + h_nopool2 = h + h = self.pool2(h) + h2 = h_nopool2 # (128x128)*128 + + h = self.relu3_1(self.bn3_1(self.conv3_1(h))) + h = self.relu3_2(self.bn3_2(self.conv3_2(h))) + h = self.relu3_3(self.bn3_3(self.conv3_3(h))) + h = self.relu3_4(self.bn3_4(self.conv3_4(h))) + h_nopool3 = h + h = self.pool3(h) + h3 = h_nopool3 # (64x64)*256 + + h = self.relu4_1(self.bn4_1(self.conv4_1(h))) + h = self.relu4_2(self.bn4_2(self.conv4_2(h))) + h = self.relu4_3(self.bn4_3(self.conv4_3(h))) + h = self.relu4_4(self.bn4_4(self.conv4_4(h))) + h_nopool4 = h + h = self.pool4(h) + h4 = h_nopool4 # (32x32)*512 + + h = self.relu5_1(self.bn5_1(self.conv5_1(h))) + h = self.relu5_2(self.bn5_2(self.conv5_2(h))) + h = self.relu5_3(self.bn5_3(self.conv5_3(h))) + h = self.relu5_4(self.bn5_4(self.conv5_4(h))) + h5 = h # (16x16)*512 + + + return h1,h2,h3,h4,h5 + + + + def copy_params_from_vgg19_bn(self, vgg19_bn): + features = [ + self.conv1_1, self.bn1_1, self.relu1_1, + self.conv1_2, self.bn1_2, self.relu1_2, + self.pool1, + self.conv2_1, self.bn2_1, self.relu2_1, + self.conv2_2, self.bn2_2, self.relu2_2, + self.pool2, + self.conv3_1, self.bn3_1, self.relu3_1, + self.conv3_2, self.bn3_2, self.relu3_2, + self.conv3_3, self.bn3_3, self.relu3_3, + self.conv3_4, self.bn3_4, self.relu3_4, + self.pool3, + self.conv4_1, self.bn4_1, self.relu4_1, + self.conv4_2, self.bn4_2, self.relu4_2, + self.conv4_3, self.bn4_3, self.relu4_3, + self.conv4_4, self.bn4_4, self.relu4_4, + self.pool4, + self.conv5_1, self.bn5_1, self.relu5_1, + self.conv5_2, self.bn5_2, self.relu5_2, + self.conv5_3, self.bn5_3, self.relu5_3, + self.conv5_4, self.bn5_4, self.relu5_4, + ] + for l1, l2 in zip(vgg19_bn.features, features): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + + +#################################### Depth Network ##################################### +class DepthNet(nn.Module): + def __init__(self, n_class=2): + super(DepthNet, self).__init__() + + # original image's size = 256*256*3 + + # conv1 + self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) + self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu1_1 = nn.ReLU(inplace=True) + self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) + self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) + self.relu1_2 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers + + # conv2 + self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) + self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) + self.relu2_1 = nn.ReLU(inplace=True) + self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) + self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) + self.relu2_2 = nn.ReLU(inplace=True) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers + + # conv3 + self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) + self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_1 = nn.ReLU(inplace=True) + self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_2 = nn.ReLU(inplace=True) + 
self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_3 = nn.ReLU(inplace=True) + self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) + self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) + self.relu3_4 = nn.ReLU(inplace=True) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers + + # conv4 + self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) + self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_1 = nn.ReLU(inplace=True) + self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_2 = nn.ReLU(inplace=True) + self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_3 = nn.ReLU(inplace=True) + self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) + self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu4_4 = nn.ReLU(inplace=True) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers + + # conv5 + self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_1 = nn.ReLU(inplace=True) + self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_2 = nn.ReLU(inplace=True) + self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_3 = nn.ReLU(inplace=True) + self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) + self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) + self.relu5_4 = nn.ReLU(inplace=True) # 1/32 4 layers + + # depth vector + self.conv_fcn2 = nn.Conv2d(512, 64, 3, padding=1) + self.pool_avg = nn.AvgPool2d(16, stride=2, ceil_mode=True) + self.conv_c = nn.Conv2d(64, 6, 1, padding=0) + + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # m.weight.data.zero_() + nn.init.normal(m.weight.data, std=0.01) + if m.bias is not None: + m.bias.data.zero_() + if isinstance(m, nn.ConvTranspose2d): + assert m.kernel_size[0] == m.kernel_size[1] + initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) + m.weight.data.copy_(initial_weight) + + def forward(self, x): + h = x + + h = self.relu1_1(self.bn1_1(self.conv1_1(h))) + h = self.relu1_2(self.bn1_2(self.conv1_2(h))) + h_nopool1 = h + h = self.pool1(h) + d1 = h_nopool1 # (256x256)*64 + + h = self.relu2_1(self.bn2_1(self.conv2_1(h))) + h = self.relu2_2(self.bn2_2(self.conv2_2(h))) + h_nopool2 = h + h = self.pool2(h) + d2 = h_nopool2 # (128x128)*128 + + h = self.relu3_1(self.bn3_1(self.conv3_1(h))) + h = self.relu3_2(self.bn3_2(self.conv3_2(h))) + h = self.relu3_3(self.bn3_3(self.conv3_3(h))) + h = self.relu3_4(self.bn3_4(self.conv3_4(h))) + h_nopool3 = h + h = self.pool3(h) + d3 = h_nopool3 # (64x64)*256 + + h = self.relu4_1(self.bn4_1(self.conv4_1(h))) + h = self.relu4_2(self.bn4_2(self.conv4_2(h))) + h = self.relu4_3(self.bn4_3(self.conv4_3(h))) + h = self.relu4_4(self.bn4_4(self.conv4_4(h))) + h_nopool4 = h + h = self.pool4(h) + d4 = h_nopool4 # (32x32)*512 + + h = self.relu5_1(self.bn5_1(self.conv5_1(h))) + h = self.relu5_2(self.bn5_2(self.conv5_2(h))) + h = self.relu5_3(self.bn5_3(self.conv5_3(h))) + h = self.relu5_4(self.bn5_4(self.conv5_4(h))) + d5 = h # (16x16)*512 + + 
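+        # The depth vector computed next (conv 512->64, 16x16 average pool, conv 64->6)
+        # is a per-image 6-channel weight; fusion.py applies it (after softmax) to the six
+        # multi-scale branches f1..f6 before the recurrent attention steps.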
# depth vector + vector = self.conv_fcn2(d5) + vector = self.pool_avg(vector) + depth_vector = self.conv_c(vector) + + + + return depth_vector, d1, d2, d3, d4, d5 + + def copy_params_from_vgg19_bn(self, vgg19_bn): + features = [ + self.conv1_1, self.bn1_1, self.relu1_1, + self.conv1_2, self.bn1_2, self.relu1_2, + self.pool1, + self.conv2_1, self.bn2_1, self.relu2_1, + self.conv2_2, self.bn2_2, self.relu2_2, + self.pool2, + self.conv3_1, self.bn3_1, self.relu3_1, + self.conv3_2, self.bn3_2, self.relu3_2, + self.conv3_3, self.bn3_3, self.relu3_3, + self.conv3_4, self.bn3_4, self.relu3_4, + self.pool3, + self.conv4_1, self.bn4_1, self.relu4_1, + self.conv4_2, self.bn4_2, self.relu4_2, + self.conv4_3, self.bn4_3, self.relu4_3, + self.conv4_4, self.bn4_4, self.relu4_4, + self.pool4, + self.conv5_1, self.bn5_1, self.relu5_1, + self.conv5_2, self.bn5_2, self.relu5_2, + self.conv5_3, self.bn5_3, self.relu5_3, + self.conv5_4, self.bn5_4, self.relu5_4, + ] + for l1, l2 in zip(vgg19_bn.features, features): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + + + + + + + + diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000..31f1795 --- /dev/null +++ b/trainer.py @@ -0,0 +1,156 @@ +import math + +from torch.autograd import Variable +import torch.nn.functional as F +import torch + + + +running_loss_final = 0 + + + +def cross_entropy2d(input, target, weight=None, size_average=True): + n, c, h, w = input.size() + + input = input.transpose(1,2).transpose(2,3).contiguous() + input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] + input = input.view(-1, c) + # target: (n*h*w,) + mask = target >= 0 + target = target[mask] + loss = F.cross_entropy(input, target, weight=weight, size_average=False) + if size_average: + loss /= mask.data.sum() + return loss + + + + +class Trainer(object): + + def __init__(self, cuda, model_rgb,model_depth,model_clstm, optimizer_rgb, + optimizer_depth,optimizer_clstm, + train_loader, max_iter, snapshot, outpath, sshow, size_average=False): + self.cuda = cuda + self.model_rgb = model_rgb + self.model_depth = model_depth + self.model_clstm = model_clstm + self.optim_rgb = optimizer_rgb + self.optim_depth = optimizer_depth + self.optim_clstm = optimizer_clstm + self.train_loader = train_loader + self.epoch = 0 + self.iteration = 0 + self.max_iter = max_iter + self.snapshot = snapshot + self.outpath = outpath + self.sshow = sshow + self.size_average = size_average + + + + def train_epoch(self): + + for batch_idx, (data, target, depth) in enumerate(self.train_loader): + + + iteration = batch_idx + self.epoch * len(self.train_loader) + if self.iteration != 0 and (iteration - 1) != self.iteration: + continue # for resuming + self.iteration = iteration + if self.iteration >= self.max_iter: + break + if self.cuda: + data, target, depth = data.cuda(), target.cuda(), depth.cuda() + data, target, depth = Variable(data), Variable(target), Variable(depth) + n, c, h, w = data.size() # batch_size, channels, height, weight + depth = depth.view(n,h,w,1).repeat(1,1,1,c) + depth = depth.transpose(3,1) + depth = depth.transpose(3,2) + + + self.optim_rgb.zero_grad() + 
self.optim_depth.zero_grad() + self.optim_clstm.zero_grad() + + global running_loss_final + + + h1,h2,h3,h4,h5 = self.model_rgb(data) # RGBNet's output + depth_vector,d1,d2,d3,d4,d5 = self.model_depth(depth) # DepthNet's output + + # ------------------------------ Fusion --------------------------- # + score_fusion = self.model_clstm(depth_vector,h1,h2,h3,h4,h5,d1,d2,d3,d4,d5) # Final output + loss_all = cross_entropy2d(score_fusion, target, size_average=self.size_average) + + + + running_loss_final += loss_all.data[0] + + + if iteration % self.sshow == (self.sshow-1): + print('\n [%3d, %6d, The training loss of DMRA_Net: %.3f]' % (self.epoch + 1, iteration + 1, running_loss_final / (n * self.sshow))) + + running_loss_final = 0.0 + + + if iteration <= 200000: + if iteration % self.snapshot == (self.snapshot-1): + savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration+1)) + torch.save(self.model_rgb.state_dict(), savename) + print('save: (snapshot: %d)' % (iteration+1)) + + savename_focal = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) + torch.save(self.model_depth.state_dict(), savename_focal) + print('save: (snapshot_depth: %d)' % (iteration+1)) + + savename_clstm = ('%s/clstm_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) + torch.save(self.model_clstm.state_dict(), savename_clstm) + print('save: (snapshot_clstm: %d)' % (iteration+1)) + + else: + if iteration % 10000 == (10000 - 1): + savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) + torch.save(self.model_rgb.state_dict(), savename) + print('save: (snapshot: %d)' % (iteration + 1)) + + savename_focal = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) + torch.save(self.model_depth.state_dict(), savename_focal) + print('save: (snapshot_depth: %d)' % (iteration + 1)) + + savename_clstm = ('%s/clstm_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) + torch.save(self.model_clstm.state_dict(), savename_clstm) + print('save: (snapshot_clstm: %d)' % (iteration + 1)) + + + + if (iteration+1) == self.max_iter: + savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration+1)) + torch.save(self.model_rgb.state_dict(), savename) + print('save: (snapshot: %d)' % (iteration+1)) + + savename_focal = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) + torch.save(self.model_depth.state_dict(), savename_focal) + print('save: (snapshot_depth: %d)' % (iteration+1)) + + savename_clstm = ('%s/clstm_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) + torch.save(self.model_clstm.state_dict(), savename_clstm) + print('save: (snapshot_clstm: %d)' % (iteration+1)) + + + + + loss_all.backward() + self.optim_clstm.step() + self.optim_depth.step() + self.optim_rgb.step() + + def train(self): + max_epoch = int(math.ceil(1. 
* self.max_iter / len(self.train_loader))) + + for epoch in range(max_epoch): + self.epoch = epoch + self.train_epoch() + if self.iteration >= self.max_iter: + break diff --git a/transform.py b/transform.py new file mode 100644 index 0000000..4d29979 --- /dev/null +++ b/transform.py @@ -0,0 +1,57 @@ +import numpy as np +import torch + +from PIL import Image + +def colormap(n): #import n, then r'g'b obtain values, finally acquiring colormap + cmap=np.zeros([n, 3]).astype(np.uint8) + + for i in np.arange(n): + r, g, b = np.zeros(3) + + for j in np.arange(8): + r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) + g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) + b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) + + cmap[i,:] = np.array([r, g, b]) + + return cmap + +class Relabel: + + def __init__(self, olabel, nlabel): + self.olabel = olabel + self.nlabel = nlabel + + def __call__(self, tensor): + assert isinstance(tensor, torch.LongTensor), 'tensor needs to be LongTensor' + tensor[tensor == self.olabel] = self.nlabel + return tensor + + +class ToLabel: + + def __call__(self, image): + return torch.from_numpy(np.array(image)).long().unsqueeze(0) + + +class Colorize: + + def __init__(self, n=21): + self.cmap = colormap(256) + self.cmap[n] = self.cmap[-1] + self.cmap = torch.from_numpy(self.cmap[:n]) + + def __call__(self, gray_image): + size = gray_image.size() + color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) + + for label in range(1, len(self.cmap)): + mask = (gray_image == label) + + color_image[0][mask] = self.cmap[label][0] + color_image[1][mask] = self.cmap[label][1] + color_image[2][mask] = self.cmap[label][2] + + return color_image
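transform.py is not imported by the other files in this diff; below is a minimal usage sketch of its helpers (the toy mask and the class count n=2 are illustrative assumptions, not part of the patch):

import numpy as np
from PIL import Image
from transform import ToLabel, Relabel, Colorize

# A toy 4x4 binary mask (0 = background, 255 = salient object), standing in for the
# ground-truth PNGs loaded in dataset_loader.py.
mask = Image.fromarray(np.array([[0, 0, 255, 255]] * 4, dtype=np.uint8))

label = ToLabel()(mask)          # LongTensor of shape (1, 4, 4)
label = Relabel(255, 1)(label)   # map the 255 foreground value to class index 1
color = Colorize(n=2)(label[0])  # ByteTensor of shape (3, 4, 4), colorized label map
print(label.shape, color.shape)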