
Commit

add DMRA code
jiwei committed Oct 12, 2019
1 parent 0ead1f3 commit eec374e
Showing 10 changed files with 1,087 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
149 changes: 149 additions & 0 deletions dataset_loader.py
@@ -0,0 +1,149 @@
import os

import numpy as np
import PIL.Image
import scipy.io as sio
import torch
from torch.utils import data

class MyData(data.Dataset):  # inherits the torch Dataset interface
    """
    Load training data from a folder.
    """
    mean_rgb = np.array([0.447, 0.407, 0.386])
    std_rgb = np.array([0.244, 0.250, 0.253])

    def __init__(self, root, transform=False):
        super(MyData, self).__init__()
        self.root = root

        self._transform = transform

        img_root = os.path.join(self.root, 'train_images')
        lbl_root = os.path.join(self.root, 'train_masks')
        depth_root = os.path.join(self.root, 'train_depth')

        file_names = os.listdir(img_root)
        self.img_names = []
        self.lbl_names = []
        self.depth_names = []
        for i, name in enumerate(file_names):
            if not name.endswith('.jpg'):
                continue
            self.lbl_names.append(
                os.path.join(lbl_root, name[:-4] + '.png')
            )
            self.img_names.append(
                os.path.join(img_root, name)
            )
            self.depth_names.append(
                os.path.join(depth_root, name[:-4] + '.png')
            )

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        # load image
        img_file = self.img_names[index]
        img = PIL.Image.open(img_file)
        # img = img.resize((256, 256))
        img = np.array(img, dtype=np.uint8)
        # load label and binarize it to {0, 1}
        lbl_file = self.lbl_names[index]
        lbl = PIL.Image.open(lbl_file)
        # lbl = lbl.resize((256, 256))
        lbl = np.array(lbl, dtype=np.int32)
        lbl[lbl != 0] = 1
        # load depth
        depth_file = self.depth_names[index]
        depth = PIL.Image.open(depth_file)
        # depth = depth.resize((256, 256))
        depth = np.array(depth, dtype=np.uint8)

        if self._transform:
            return self.transform(img, lbl, depth)
        else:
            return img, lbl, depth

    # Convert the numpy arrays into normalized PyTorch tensors.
    def transform(self, img, lbl, depth):
        img = img.astype(np.float64) / 255.0
        img -= self.mean_rgb
        img /= self.std_rgb
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        depth = depth.astype(np.float64) / 255.0
        depth = torch.from_numpy(depth).float()
        return img, lbl, depth


class MyTestData(data.Dataset):
    """
    Load test data from a folder.
    """
    mean_rgb = np.array([0.447, 0.407, 0.386])
    std_rgb = np.array([0.244, 0.250, 0.253])

    def __init__(self, root, transform=False):
        super(MyTestData, self).__init__()
        self.root = root
        self._transform = transform

        img_root = os.path.join(self.root, 'test_images')
        depth_root = os.path.join(self.root, 'test_depth')
        file_names = os.listdir(img_root)
        self.img_names = []
        self.names = []
        self.depth_names = []

        for i, name in enumerate(file_names):
            if not name.endswith('.jpg'):
                continue
            self.img_names.append(
                os.path.join(img_root, name)
            )
            self.names.append(name[:-4])
            self.depth_names.append(
                # os.path.join(depth_root, name[:-4] + '_depth.png')  # Test RGBD135 dataset
                os.path.join(depth_root, name[:-4] + '.png')
            )

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        # load image
        img_file = self.img_names[index]
        img = PIL.Image.open(img_file)
        img_size = img.size
        # img = img.resize((256, 256))
        img = np.array(img, dtype=np.uint8)

        # load depth
        depth_file = self.depth_names[index]
        depth = PIL.Image.open(depth_file)
        # depth = depth.resize((256, 256))
        depth = np.array(depth, dtype=np.uint8)
        if self._transform:
            img, depth = self.transform(img, depth)
            return img, depth, self.names[index], img_size
        else:
            return img, depth, self.names[index], img_size

    def transform(self, img, depth):
        img = img.astype(np.float64) / 255.0
        img -= self.mean_rgb
        img /= self.std_rgb
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        img = torch.from_numpy(img).float()

        depth = depth.astype(np.float64) / 255.0
        depth = torch.from_numpy(depth).float()

        return img, depth
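
For reference, a minimal usage sketch of how these dataset classes can be wrapped in DataLoaders, mirroring the loader settings in demo.py below. The data roots are placeholder paths, and batching two training samples assumes every training image shares one spatial size, since the dataset itself does no resizing:

from torch.utils.data import DataLoader

from dataset_loader import MyData, MyTestData

# placeholder paths; point these at your own data roots
train_root = './data/train_data'
test_root = './data/test_data'

# transform=True returns normalized CHW float tensors instead of raw numpy arrays
train_loader = DataLoader(MyData(train_root, transform=True),
                          batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(MyTestData(test_root, transform=True),
                         batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

# only works with batch_size > 1 if all training images share the same H x W
img, lbl, depth = next(iter(train_loader))
print(img.shape, lbl.shape, depth.shape)  # e.g. (2, 3, H, W), (2, H, W), (2, H, W)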
1 change: 1 addition & 0 deletions demo.py
@@ -0,0 +1 @@
"""Title: Depth-induced Multi-scale Recurrent Attention Network for Saliency DetectionAuthor: Wei Ji, Jingjing LiE-mail: weiji.dlut@gmail.com"""import torchfrom torch.autograd import Variablefrom torch.utils.data import DataLoaderimport torchvisionimport torch.nn.functional as Fimport torch.optim as optimfrom dataset_loader import MyData, MyTestDatafrom model import RGBNet,DepthNetfrom fusion import ConvLSTMfrom functions import imsaveimport argparsefrom trainer import Trainerimport osconfigurations = { # same configuration as original work # https://github.com/shelhamer/fcn.berkeleyvision.org 1: dict( max_iteration=1000000, lr=1.0e-10, momentum=0.99, weight_decay=0.0005, spshot=20000, nclass=2, sshow=10, )}parser=argparse.ArgumentParser()parser.add_argument('--phase', type=str, default='test', help='train or test')parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters')# parser.add_argument('--train_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/train_data', help='path to train data')parser.add_argument('--train_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/train_data-augment', help='path to train data')parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/DUT-RGBD/test_data', help='path to test data')# parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/NJUD/test_data', help='path to test data')# parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/NLPR/test_data', help='path to test data')# parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/LFSD', help='path to test data')# parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/SSD', help='path to test data')# parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/STEREO', help='path to test data')# parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/RGBD135', help='path to test data') # Need to set dataset_loader.py/line 113parser.add_argument('--snapshot_root', type=str, default='./snapshot', help='path to snapshot')parser.add_argument('--salmap_root', type=str, default='./sal_map', help='path to saliency map')parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys())args = parser.parse_args()cfg = configurations[args.config]cuda = torch.cuda.is_available"""""""""""~~~ dataset loader ~~~"""""""""train_dataRoot = args.train_dataroottest_dataRoot = args.test_datarootif not os.path.exists(args.snapshot_root): os.mkdir(args.snapshot_root)if not os.path.exists(args.salmap_root): os.mkdir(args.salmap_root)if args.phase == 'train': SnapRoot = args.snapshot_root # checkpoint train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True), batch_size=2, shuffle=True, num_workers=4, pin_memory=True)else: MapRoot = args.salmap_root test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True), batch_size=1, shuffle=True, num_workers=4, pin_memory=True)print ('data already')""""""""""" ~~~nets~~~ """""""""start_epoch = 0start_iteration = 0model_rgb = RGBNet(cfg['nclass'])model_depth = DepthNet(cfg['nclass'])model_clstm = ConvLSTM(input_channels=64, hidden_channels=[64, 32, 64], kernel_size=5, step=4, effective_step=[2, 4, 8])if args.param is True: 
model_rgb.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'snapshot_iter_1000000.pth'))) model_depth.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'depth_snapshot_iter_1000000.pth'))) model_clstm.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'clstm_snapshot_iter_1000000.pth')))else: vgg19_bn = torchvision.models.vgg19_bn(pretrained=True) model_rgb.copy_params_from_vgg19_bn(vgg19_bn) model_depth.copy_params_from_vgg19_bn(vgg19_bn)if cuda: model_rgb = model_rgb.cuda() model_depth = model_depth.cuda() model_clstm = model_clstm.cuda()if args.phase == 'train': # Trainer: class, defined in trainer.py optimizer_rgb = optim.SGD(model_rgb.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) optimizer_depth = optim.SGD(model_depth.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) optimizer_clstm = optim.SGD(model_clstm.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) training = Trainer( cuda=cuda, model_rgb=model_rgb, model_depth=model_depth, model_clstm=model_clstm, optimizer_rgb=optimizer_rgb, optimizer_depth=optimizer_depth, optimizer_clstm=optimizer_clstm, train_loader=train_loader, max_iter=cfg['max_iteration'], snapshot=cfg['spshot'], outpath=args.snapshot_root, sshow=cfg['sshow'] ) training.epoch = start_epoch training.iteration = start_iteration training.train()else: for id, (data, depth, img_name, img_size) in enumerate(test_loader): print('testing bach %d' % (id+1)) inputs = Variable(data).cuda() inputs_depth = Variable(depth).cuda() n, c, h, w = inputs.size() depth = inputs_depth.view(n, h, w, 1).repeat(1, 1, 1, c) depth = depth.transpose(3, 1) depth = depth.transpose(3, 2) h1, h2, h3, h4, h5 = model_rgb(inputs) # RGBNet's output depth_vector, d1, d2, d3, d4, d5 = model_depth(depth) # DepthNet's output outputs_all = model_clstm(depth_vector, h1, h2, h3, h4, h5, d1, d2, d3, d4, d5) # Final output outputs_all = F.softmax(outputs_all, dim=1) outputs = outputs_all[0][1] outputs = outputs.cpu().data.resize_(h, w) imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size) print('The testing process has finished!')
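
The depth handling in the test loop above (view, repeat, then two transposes) simply replicates the single-channel depth map across the three input channels. A small sketch with made-up sizes that checks the equivalence against the more direct unsqueeze/repeat form:

import torch

n, c, h, w = 1, 3, 4, 5                  # made-up sizes for illustration
inputs_depth = torch.rand(n, h, w)       # single-channel depth, as loaded by MyTestData

# the formulation used in demo.py
d1 = inputs_depth.view(n, h, w, 1).repeat(1, 1, 1, c)
d1 = d1.transpose(3, 1).transpose(3, 2)  # -> (n, c, h, w)

# a more direct equivalent
d2 = inputs_depth.unsqueeze(1).repeat(1, c, 1, 1)  # -> (n, c, h, w)

assert d1.shape == (n, c, h, w)
assert torch.equal(d1, d2)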
Binary file added figure/dataset.png
Binary file not shown.
Binary file added figure/overall.png
Binary file not shown.
24 changes: 24 additions & 0 deletions functions.py
@@ -0,0 +1,24 @@
import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.misc import imresize  # deprecated; removed in SciPy >= 1.3, so an older SciPy is required

def imsave(file_name, img, img_size):
    """
    Save a torch tensor as an image.
    :param file_name: 'image/folder/image_name'
    :param img: 2-D (h*w) or 3-D (3*h*w) torch tensor
    :param img_size: original (width, height) of the image, as collated by the test DataLoader
    :return: nothing
    """
    assert torch.is_tensor(img) and img.dtype == torch.float32, \
        'img must be a torch.FloatTensor'
    ndim = len(img.size())
    assert ndim == 2 or ndim == 3, \
        'img must be a 2 or 3 dimensional tensor'

    img = img.numpy()
    # resize back to the original resolution; img_size is (width, height)
    img = imresize(img, [img_size[1][0], img_size[0][0]], interp='nearest')
    if ndim == 3:
        plt.imsave(file_name, np.transpose(img, (1, 2, 0)))
    else:
        plt.imsave(file_name, img, cmap='gray')
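
Because scipy.misc.imresize is gone from recent SciPy releases, a rough PIL-based stand-in for the nearest-neighbour call used here may be useful. This sketch assumes a 2-D float saliency map; the imresize_pil name and the min-max rescaling to uint8 are my own approximation of the old bytescale behaviour, not part of this repository:

import numpy as np
from PIL import Image


def imresize_pil(arr, size):
    """Approximate scipy.misc.imresize(arr, size, interp='nearest') for a 2-D array.

    Rescales the array to 0-255 uint8 (roughly what the old bytescale step did),
    then resizes with nearest-neighbour sampling. `size` is (height, width),
    matching the old SciPy convention.
    """
    arr = np.asarray(arr, dtype=np.float64)
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        arr = (arr - lo) / (hi - lo) * 255.0
    else:
        arr = np.zeros_like(arr)
    h, w = size
    resized = Image.fromarray(arr.astype(np.uint8)).resize((w, h), resample=Image.NEAREST)
    return np.array(resized)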
