diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml
new file mode 100644
index 0000000..48ac639
--- /dev/null
+++ b/.github/workflows/deploy.yaml
@@ -0,0 +1,27 @@
+name: GitHub Pages
+
+on:
+ push:
+ branches:
+ - master
+
+jobs:
+ deploy:
+ runs-on: ubuntu-20.04
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Setup Node
+ uses: actions/setup-node@v2.1.2
+
+ - name: Build
+ run: |
+ cd website
+ yarn
+ yarn build
+
+ - name: deploy
+ uses: peaceiris/actions-gh-pages@v3
+ with:
+ github_token: ${{ secrets.TOKEN }}
+ publish_dir: ./website/build
diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml
new file mode 100644
index 0000000..fcf7382
--- /dev/null
+++ b/.github/workflows/push.yaml
@@ -0,0 +1,28 @@
+name: Push to Master Repo
+
+on:
+ push:
+ branches:
+ - master
+ workflow_dispatch:
+
+jobs:
+ deploy:
+ runs-on: ubuntu-20.04
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ ref: 'master'
+ - name: Install SSH key
+ uses: shimataro/ssh-key-action@v2
+ with:
+ key: ${{ secrets.SSH_KEY }}
+ known_hosts: ${{ secrets.KNOWN_HOSTS }}
+ - run: |
+ git remote add official git@github.com:KentoNishi/JTR-CVPR-2024.git
+ git config --global user.email ${{ secrets.EMAIL }}
+ git config --global user.name ${{ secrets.USERNAME }}
+ git checkout master
+ # squash all commits into one
+ git reset $(git commit-tree HEAD^{tree} -m "Update at $(date +'%Y-%m-%d %H:%M:%S')")
+ git push -u official master -f
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3d66e44
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,162 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+# lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other info into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+ignore
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..3becc19
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Kento Nishi
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..35a1c13
--- /dev/null
+++ b/README.md
@@ -0,0 +1,31 @@
+# JTR-CVPR-2024
+
+
+
+
+
+
+## Updates
+* **May 2024:** Code released for NYUv2 `onelabel` and `randomlabels`.
+* **May 2024:** Website updated with the CVPR poster and video.
+* **April 2024:** Paper website published at [kentonishi.com/JTR-CVPR-2024](https://kentonishi.com/JTR-CVPR-2024).
+
+
+## Training
+```bash
+# NYUv2 onelabel
+python train_nyuv2.py \
+ --data-dir [/some/data/dir] \
+ --out-dir [/some/output/dir/nyuv2_onelabel] \
+ --ssl-type onelabel \
+ --label-dir ./data/nyuv2_settings \
+ --seg-baseline 25.75 --depth-baseline 0.6511 --norm-baseline 33.73
+
+# NYUv2 randomlabels
+python train_nyuv2.py \
+ --data-dir [/some/data/dir] \
+ --out-dir [/some/output/dir/nyuv2_randomlabels] \
+ --ssl-type randomlabels \
+ --label-dir ./data/nyuv2_settings \
+ --seg-baseline 27.05 --depth-baseline 0.6626 --norm-baseline 33.58
+```
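+
+The `--seg-baseline`, `--depth-baseline`, and `--norm-baseline` flags are per-task reference scores (segmentation mIoU in percent, depth absolute error, and mean normal angle error in degrees) against which the relative multi-task improvement ΔMTL is reported. A minimal sketch of the metric, mirroring `compute_delta_mtl_nyuv2` in `code/utils/deltas.py`:
+
+```python
+def delta_mtl_nyuv2(miou, depth_err, normal_err, seg_base, depth_base, normal_base):
+    # mIoU comes out of the evaluator in [0, 1]; the baseline is given in percent.
+    seg_delta = (miou * 100 - seg_base) / seg_base            # higher mIoU is better
+    depth_delta = (depth_base - depth_err) / depth_base       # lower error is better
+    normal_delta = (normal_base - normal_err) / normal_base   # lower error is better
+    return (seg_delta + depth_delta + normal_delta) / 3
+```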
diff --git a/code/data/nyuv2_settings/onelabel.pth b/code/data/nyuv2_settings/onelabel.pth
new file mode 100644
index 0000000..c629fdf
Binary files /dev/null and b/code/data/nyuv2_settings/onelabel.pth differ
diff --git a/code/data/nyuv2_settings/randomlabels.pth b/code/data/nyuv2_settings/randomlabels.pth
new file mode 100644
index 0000000..9bf574d
Binary files /dev/null and b/code/data/nyuv2_settings/randomlabels.pth differ
diff --git a/code/data/pascal-context b/code/data/pascal-context
new file mode 120000
index 0000000..352e40d
--- /dev/null
+++ b/code/data/pascal-context
@@ -0,0 +1 @@
+/n/home11/jskim/vcg_natural/multi-task/pascal-context
\ No newline at end of file
diff --git a/code/datasets/nyuv2.py b/code/datasets/nyuv2.py
new file mode 100644
index 0000000..142e065
--- /dev/null
+++ b/code/datasets/nyuv2.py
@@ -0,0 +1,148 @@
+from torch.utils.data.dataset import Dataset
+
+import os
+import torch
+import fnmatch
+import numpy as np
+import random
+import torch.nn.functional as F
+from .randaugment import ImgAugment
+
+
+class RandomScaleCrop(object):
+ """
+ Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
+ """
+
+ def __init__(self, scale=[1.0, 1.2, 1.5]):
+ self.scale = scale
+
+ def __call__(self, img, label, depth, normal):
+ height, width = img.shape[-2:]
+ sc = self.scale[random.randint(0, len(self.scale) - 1)]
+ h, w = int(height / sc), int(width / sc)
+ i = random.randint(0, height - h)
+ j = random.randint(0, width - w)
+ # pdb.set_trace()
+ img_ = F.interpolate(
+ img[None, :, i : i + h, j : j + w],
+ size=(height, width),
+ mode="bilinear",
+ align_corners=True,
+ ).squeeze(0)
+ label_ = (
+ F.interpolate(
+ label[None, None, i : i + h, j : j + w],
+ size=(height, width),
+ mode="nearest",
+ )
+ .squeeze(0)
+ .squeeze(0)
+ )
+ depth_ = F.interpolate(
+ depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
+ ).squeeze(0)
+ normal_ = F.interpolate(
+ normal[None, :, i : i + h, j : j + w],
+ size=(height, width),
+ mode="bilinear",
+ align_corners=True,
+ ).squeeze(0)
+ _sc = sc
+ _h, _w, _i, _j = h, w, i, j
+
+ return (
+ img_,
+ label_,
+ depth_ / sc,
+ normal_,
+ torch.tensor([_sc, _h, _w, _i, _j, height, width]),
+ )
+
+
+class NYUv2_MTL(Dataset):
+ """
+    NYUv2 multi-task dataset; adapted from https://pytorch.org/docs/stable/torchvision/datasets.html.
+ """
+
+ def __init__(
+ self,
+ root,
+ train=True,
+ ):
+ self.train = train
+ self.root = os.path.expanduser(root)
+ self.augmenter = ImgAugment(
+ num_ops=2, magnitude=6, num_magnitude_bins=31, magnitude_sampling=True
+ )
+
+        # Read the data file
+ if train:
+ self.data_path = root + "/train"
+ else:
+ self.data_path = root + "/val"
+
+ # calculate data length
+ self.data_len = len(
+ fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy")
+ )
+
+ def __getitem__(self, index):
+        # load the image and per-task labels from the saved .npy files
+ image = torch.from_numpy(
+ np.moveaxis(
+ np.load(self.data_path + "/image/{:d}.npy".format(index)), -1, 0
+ )
+ )
+ semantic = torch.from_numpy(
+ np.load(self.data_path + "/label/{:d}.npy".format(index))
+ )
+ depth = torch.from_numpy(
+ np.moveaxis(
+ np.load(self.data_path + "/depth/{:d}.npy".format(index)), -1, 0
+ )
+ )
+ normal = torch.from_numpy(
+ np.moveaxis(
+ np.load(self.data_path + "/normal/{:d}.npy".format(index)), -1, 0
+ )
+ )
+
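+        # Training: take one random-scale crop of the sample, derive two further
+        # random crops from that view, and apply RandAugment to each of the three
+        # image views (labels are cropped consistently with their view).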
+ if self.train:
+ image_0, semantic_0, depth_0, normal_0, _ = RandomScaleCrop()(
+ image, semantic, depth, normal
+ )
+ image_1, semantic_1, depth_1, normal_1, _ = RandomScaleCrop()(
+ image_0, semantic_0, depth_0, normal_0
+ )
+ image_2, semantic_2, depth_2, normal_2, _ = RandomScaleCrop()(
+ image_0, semantic_0, depth_0, normal_0
+ )
+ image_0 = self.augmenter(image_0)
+ image_1 = self.augmenter(image_1)
+ image_2 = self.augmenter(image_2)
+ return (
+ index,
+ image_0.type(torch.FloatTensor),
+ semantic_0.type(torch.FloatTensor),
+ depth_0.type(torch.FloatTensor),
+ normal_0.type(torch.FloatTensor),
+ image_1.type(torch.FloatTensor),
+ semantic_1.type(torch.FloatTensor),
+ depth_1.type(torch.FloatTensor),
+ normal_1.type(torch.FloatTensor),
+ image_2.type(torch.FloatTensor),
+ semantic_2.type(torch.FloatTensor),
+ depth_2.type(torch.FloatTensor),
+ normal_2.type(torch.FloatTensor),
+ )
+ else:
+ return (
+ image.type(torch.FloatTensor),
+ semantic.type(torch.FloatTensor),
+ depth.type(torch.FloatTensor),
+ normal.type(torch.FloatTensor),
+ )
+
+ def __len__(self):
+ return self.data_len
diff --git a/code/datasets/randaugment.py b/code/datasets/randaugment.py
new file mode 100644
index 0000000..ab54793
--- /dev/null
+++ b/code/datasets/randaugment.py
@@ -0,0 +1,194 @@
+######
+# Augmentation on pytorch tensor
+######
+import torch
+from torch import Tensor
+import math
+import torchvision.transforms as T
+from torchvision.transforms import functional as F, InterpolationMode
+from typing import Dict, List, Optional, Tuple
+
+
+class ImgAugment(torch.nn.Module):
+ """
+ Reference: RandAugment
+ https://pytorch.org/vision/main/_modules/torchvision/transforms/autoaugment.html#RandAugment
+ """
+
+ def __init__(
+ self,
+ num_ops: int = 2,
+ magnitude: int = 6, # default 9, range from 0 ~ num_magnitude_bins
+ num_magnitude_bins: int = 31, # magnitude resolution
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
+ fill: Optional[List[float]] = None,
+ magnitude_sampling=True,
+ ) -> None:
+ super().__init__()
+ self.num_ops = num_ops
+ self.magnitude = magnitude
+ self.num_magnitude_bins = num_magnitude_bins
+ self.interpolation = interpolation
+ self.fill = fill
+ self.magnitude_sampling = magnitude_sampling
+
+ def _augmentation_space(
+ self, num_bins: int, image_size: Tuple[int, int]
+ ) -> Dict[str, Tuple[Tensor, bool]]:
+ return {
+ # op_name: (magnitudes, signed)
+ "Identity": (torch.tensor(0.0), False),
+ # "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
+ # "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
+ # "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+ # "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+ # "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
+ "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+ "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+ "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+ "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+ "Posterize": (
+ 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(),
+ False,
+ ),
+ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
+ "AutoContrast": (torch.tensor(0.0), False),
+ "Equalize": (torch.tensor(0.0), False),
+ }
+
+ def forward(self, img: Tensor) -> Tensor:
+ """
+ img (PIL Image or Tensor): Image to be transformed.
+
+ Returns:
+ PIL Image or Tensor: Transformed image.
+ """
+ fill = self.fill
+ # channels, height, width = F.get_dimensions(img)
+ channels, height, width = img.shape
+ if isinstance(img, Tensor):
+ if isinstance(fill, (int, float)):
+ fill = [float(fill)] * channels
+ elif fill is not None:
+ fill = [float(f) for f in fill]
+
+ op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
+ for _ in range(self.num_ops):
+ op_index = int(torch.randint(len(op_meta), (1,)).item())
+ op_name = list(op_meta.keys())[op_index]
+ magnitudes, signed = op_meta[op_name]
+ if self.magnitude_sampling:
+ sampled_magnitude = torch.randint(self.magnitude, (1,)).item()
+ else:
+ sampled_magnitude = self.magnitude
+ magnitude = (
+ float(magnitudes[sampled_magnitude].item())
+ if magnitudes.ndim > 0
+ else 0.0
+ )
+ if signed and torch.randint(2, (1,)):
+ magnitude *= -1.0
+ img = _apply_op(
+ img, op_name, magnitude, interpolation=self.interpolation, fill=fill
+ )
+
+ return img
+
+
+def _apply_op(
+ img: Tensor,
+ op_name: str,
+ magnitude: float,
+ interpolation: InterpolationMode,
+ fill: Optional[List[float]],
+):
+ if op_name == "ShearX":
+ # magnitude should be arctan(magnitude)
+ # official autoaug: (1, level, 0, 0, 1, 0)
+ # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
+ # compared to
+ # torchvision: (1, tan(level), 0, 0, 1, 0)
+ # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
+ img = F.affine(
+ img,
+ angle=0.0,
+ translate=[0, 0],
+ scale=1.0,
+ shear=[math.degrees(math.atan(magnitude)), 0.0],
+ interpolation=interpolation,
+ fill=fill,
+ center=[0, 0],
+ )
+ elif op_name == "ShearY":
+ # magnitude should be arctan(magnitude)
+ # See above
+ img = F.affine(
+ img,
+ angle=0.0,
+ translate=[0, 0],
+ scale=1.0,
+ shear=[0.0, math.degrees(math.atan(magnitude))],
+ interpolation=interpolation,
+ fill=fill,
+ center=[0, 0],
+ )
+ elif op_name == "TranslateX":
+ img = F.affine(
+ img,
+ angle=0.0,
+ translate=[int(magnitude), 0],
+ scale=1.0,
+ interpolation=interpolation,
+ shear=[0.0, 0.0],
+ fill=fill,
+ )
+ elif op_name == "TranslateY":
+ img = F.affine(
+ img,
+ angle=0.0,
+ translate=[0, int(magnitude)],
+ scale=1.0,
+ interpolation=interpolation,
+ shear=[0.0, 0.0],
+ fill=fill,
+ )
+ elif op_name == "Rotate":
+ img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
+ elif op_name == "Brightness":
+ img = F.adjust_brightness(img, 1.0 + magnitude)
+ elif op_name == "Color":
+ img = F.adjust_saturation(img, 1.0 + magnitude)
+ elif op_name == "Contrast":
+ img = F.adjust_contrast(img, 1.0 + magnitude)
+ elif op_name == "Sharpness":
+ img = F.adjust_sharpness(img, 1.0 + magnitude)
+ elif op_name == "Posterize":
+ if torch.is_tensor(img):
+ img = T.ToPILImage()(img)
+ img = F.posterize(img, int(magnitude))
+ img = T.ToTensor()(img)
+ else:
+ img = F.posterize(img, int(magnitude))
+ elif op_name == "Solarize":
+ if torch.is_tensor(img):
+ img = T.ToPILImage()(img)
+ img = F.solarize(img, magnitude)
+ img = T.ToTensor()(img)
+ else:
+ img = F.solarize(img, magnitude)
+ elif op_name == "AutoContrast":
+ img = F.autocontrast(img)
+ elif op_name == "Equalize":
+ if torch.is_tensor(img):
+ img = T.ToPILImage()(img)
+ img = F.equalize(img)
+ img = T.ToTensor()(img)
+ else:
+ img = F.equalize(img)
+ elif op_name == "Invert":
+ img = F.invert(img)
+ elif op_name == "Identity":
+ pass
+ else:
+ raise ValueError(f"The provided operator {op_name} is not recognized.")
+ return img
diff --git a/code/models/nyuv2_segnet_jtr.py b/code/models/nyuv2_segnet_jtr.py
new file mode 100644
index 0000000..0d36aa9
--- /dev/null
+++ b/code/models/nyuv2_segnet_jtr.py
@@ -0,0 +1,294 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from utils.process import (
+ preprocess_depth_gt,
+ preprocess_normal_gt,
+ preprocess_seg_gt,
+ preprocess_seg_pred,
+ get_jtr_mask_intersection_nyuv2,
+ preprocess_depth_pred,
+ preprocess_normal_pred,
+ get_recon_loss_depth,
+ get_recon_loss_normal,
+ get_recon_loss_seg,
+)
+
+
+class SegNet_JTR(nn.Module):
+ def __init__(self, channel_dims):
+ super(SegNet_JTR, self).__init__()
+
+ self.channel_dims = channel_dims
+ self.num_tasks = len(self.channel_dims)
+ in_channels = sum(channel_dims)
+
+ # initialise network parameters
+ filter = [64, 128, 256, 512, 512]
+
+ # define encoder decoder layers
+ self.encoder_block = nn.ModuleList([self.conv_layer([in_channels, filter[0]])])
+ self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
+ for i in range(4):
+ self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
+ self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
+
+ # define convolution layer
+ self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
+ self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
+ for i in range(4):
+ if i == 0:
+ self.conv_block_enc.append(
+ self.conv_layer([filter[i + 1], filter[i + 1]])
+ )
+ self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
+ else:
+ self.conv_block_enc.append(
+ nn.Sequential(
+ self.conv_layer([filter[i + 1], filter[i + 1]]),
+ self.conv_layer([filter[i + 1], filter[i + 1]]),
+ )
+ )
+ self.conv_block_dec.append(
+ nn.Sequential(
+ self.conv_layer([filter[i], filter[i]]),
+ self.conv_layer([filter[i], filter[i]]),
+ )
+ )
+
+ # define task specific layers
+ self.pred_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
+ ),
+ nn.Conv2d(
+ in_channels=filter[0],
+ out_channels=in_channels,
+ kernel_size=1,
+ padding=0,
+ ),
+ )
+
+ # define pooling and unpooling functions
+ self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
+ self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
+
+ self.batchnorm = nn.BatchNorm2d(num_features=in_channels)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ # define convolutional block
+ def conv_layer(self, channel):
+ conv_block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=channel[0],
+ out_channels=channel[1],
+ kernel_size=3,
+ padding=1,
+ ),
+ nn.BatchNorm2d(num_features=channel[1]),
+ nn.ReLU(inplace=True),
+ )
+ return conv_block
+
+ def forward(self, x, return_feat_only=False):
+ g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
+ [0] * 5 for _ in range(5)
+ )
+ for i in range(5):
+ g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
+
+ # global shared encoder-decoder network
+ for i in range(5):
+ if i == 0:
+ g_encoder[i][0] = self.encoder_block[i](x)
+ g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
+ g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
+ else:
+ g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
+ g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
+ g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
+ feat = [g_maxpool[i]]
+ if return_feat_only:
+ return feat[0]
+ # feat = [g_maxpool]
+ for i in range(5):
+ if i == 0:
+ g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
+ g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
+ g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
+ else:
+ g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
+ g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
+ g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
+
+ feat.append(g_decoder[i][1])
+ # feat.append(g_decoder)
+
+ # define task prediction layers
+ t_pred = self.pred_layer(g_decoder[i][1])
+
+ return feat[0], t_pred
+
+ def compute_losses(
+ self,
+ items_1,
+ items_2,
+ labeled_tasks,
+ num_classes,
+ recon_dist_weight,
+ ):
+ [
+ train_pred_seg_1,
+ train_pred_depth_1,
+ train_pred_normal_1,
+ train_seg_1,
+ train_depth_1,
+ train_normal_1,
+ ] = [
+ items_1[i].detach() if i in labeled_tasks else items_1[i]
+ for i in range(3 * 2)
+ ]
+ [
+ train_pred_seg_2,
+ train_pred_depth_2,
+ train_pred_normal_2,
+ train_seg_2,
+ train_depth_2,
+ train_normal_2,
+ ] = [
+ items_2[i].detach() if i in labeled_tasks else items_2[i]
+ for i in range(3 * 2)
+ ]
+
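+        # Predictions for tasks with labels were detached above, so the JTR losses
+        # below mainly shape the predictions of the unlabeled tasks.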
+ mask_1 = get_jtr_mask_intersection_nyuv2(
+ train_seg_1, train_depth_1, train_normal_1, labeled_tasks
+ )
+ mask_2 = get_jtr_mask_intersection_nyuv2(
+ train_seg_2, train_depth_2, train_normal_2, labeled_tasks
+ )
+
+ Y_hat_1 = torch.cat(
+ [
+ preprocess_seg_pred(train_pred_seg_1, mask_1, hard=True),
+ preprocess_depth_pred(train_pred_depth_1, mask_1),
+ preprocess_normal_pred(train_pred_normal_1, mask_1),
+ ],
+ dim=1,
+ )
+ Y_1 = torch.cat(
+ [
+ (
+ preprocess_seg_gt(train_seg_1, num_classes)
+ if 0 in labeled_tasks
+ else preprocess_seg_pred(train_pred_seg_1, mask_1, hard=True)
+ ),
+ (
+ preprocess_depth_gt(train_depth_1)
+ if 1 in labeled_tasks
+ else preprocess_depth_pred(train_pred_depth_1, mask_1)
+ ),
+ (
+ preprocess_normal_gt(train_normal_1)
+ if 2 in labeled_tasks
+ else preprocess_normal_pred(train_pred_normal_1, mask_1)
+ ),
+ ],
+ dim=1,
+ )
+ Y_hat_2 = torch.cat(
+ [
+ preprocess_seg_pred(train_pred_seg_2, mask_2),
+ preprocess_depth_pred(train_pred_depth_2, mask_2),
+ preprocess_normal_pred(train_pred_normal_2, mask_2),
+ ],
+ dim=1,
+ )
+ Y_2 = torch.cat(
+ [
+ (
+ preprocess_seg_gt(train_seg_2, num_classes)
+ if 0 in labeled_tasks
+ else preprocess_seg_pred(train_pred_seg_2, mask_2, hard=True)
+ ),
+ (
+ preprocess_depth_gt(train_depth_2)
+ if 1 in labeled_tasks
+ else preprocess_depth_pred(train_pred_depth_2, mask_2)
+ ),
+ (
+ preprocess_normal_gt(train_normal_2)
+ if 2 in labeled_tasks
+ else preprocess_normal_pred(train_pred_normal_2, mask_2)
+ ),
+ ],
+ dim=1,
+ )
+
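+        # A frozen deep copy is used for the first feature-distance term, so its
+        # gradient flows into the task predictions rather than into the JTR weights;
+        # the JTR itself is trained through the detached-input passes below.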
+ copied = copy.deepcopy(self)
+
+ feat_Y_1 = copied(Y_1, return_feat_only=True)
+ feat_Y_hat_1 = copied(Y_hat_1, return_feat_only=True)
+ feat_Y_2, recon_Y_2 = self(Y_2.detach())
+ feat_Y_hat_2, recon_Y_hat_2 = self(Y_hat_2.detach())
+ recon_Y_2_seg, recon_Y_2_depth, recon_Y_2_normal = torch.split(
+ recon_Y_2, self.channel_dims, dim=1
+ )
+ recon_Y_hat_2_seg, recon_Y_hat_2_depth, recon_Y_hat_2_normal = torch.split(
+ recon_Y_hat_2, self.channel_dims, dim=1
+ )
+
+ L_dist = (
+ 1 - F.cosine_similarity(feat_Y_1, feat_Y_hat_1, dim=1, eps=1e-12).mean()
+ ) + recon_dist_weight * (
+ 1 - F.cosine_similarity(feat_Y_2, feat_Y_hat_2, dim=1, eps=1e-12).mean()
+ )
+ L_recon = (
+ get_recon_loss_seg(
+ recon_Y_2_seg,
+ (
+ preprocess_seg_gt(train_seg_2, num_classes)
+ if 0 in labeled_tasks
+ else preprocess_seg_pred(train_pred_seg_2, mask_2, hard=True)
+ ),
+ mask_2,
+ )
+ + get_recon_loss_depth(
+ recon_Y_2_depth,
+ (train_depth_2 if 1 in labeled_tasks else train_pred_depth_2),
+ mask_2,
+ )
+ + get_recon_loss_normal(
+ recon_Y_2_normal,
+ (train_normal_2 if 2 in labeled_tasks else train_pred_normal_2),
+ mask_2,
+ )
+ + get_recon_loss_seg(
+ recon_Y_hat_2_seg,
+ preprocess_seg_pred(train_pred_seg_2, mask_2, hard=True),
+ mask_2,
+ )
+ + get_recon_loss_depth(
+ recon_Y_hat_2_depth,
+ preprocess_depth_pred(train_pred_depth_2, mask_2),
+ mask_2,
+ )
+ + get_recon_loss_normal(
+ recon_Y_hat_2_normal,
+ preprocess_normal_pred(train_pred_normal_2, mask_2),
+ mask_2,
+ )
+ )
+
+ return L_dist, L_recon
diff --git a/code/models/nyuv2_segnet_mtl.py b/code/models/nyuv2_segnet_mtl.py
new file mode 100644
index 0000000..569b114
--- /dev/null
+++ b/code/models/nyuv2_segnet_mtl.py
@@ -0,0 +1,273 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from torch.autograd import Variable
+import torch.nn.init as init
+import numpy as np
+import pdb
+
+# Define SegNet
+# The implementation of SegNet is from https://github.com/lorenmt/mtan
+
+
+class SegNet_MTL(nn.Module):
+ def __init__(self, num_classes=13):
+ super(SegNet_MTL, self).__init__()
+
+ # initialise network parameters
+ filter = [64, 128, 256, 512, 512]
+
+ self.num_classes = num_classes
+
+ # define encoder decoder layers
+ self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
+ self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
+ for i in range(4):
+ self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
+ self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))
+
+ # define convolution layer
+ self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
+ self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
+ for i in range(4):
+ if i == 0:
+ self.conv_block_enc.append(
+ self.conv_layer([filter[i + 1], filter[i + 1]])
+ )
+ self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
+ else:
+ self.conv_block_enc.append(
+ nn.Sequential(
+ self.conv_layer([filter[i + 1], filter[i + 1]]),
+ self.conv_layer([filter[i + 1], filter[i + 1]]),
+ )
+ )
+ self.conv_block_dec.append(
+ nn.Sequential(
+ self.conv_layer([filter[i], filter[i]]),
+ self.conv_layer([filter[i], filter[i]]),
+ )
+ )
+
+ # define task specific layers
+ self.pred_task1 = nn.Sequential(
+ nn.Conv2d(
+ in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
+ ),
+ nn.Conv2d(
+ in_channels=filter[0],
+ out_channels=self.num_classes,
+ kernel_size=1,
+ padding=0,
+ ),
+ )
+ self.pred_task2 = nn.Sequential(
+ nn.Conv2d(
+ in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
+ ),
+ nn.Conv2d(in_channels=filter[0], out_channels=1, kernel_size=1, padding=0),
+ )
+ self.pred_task3 = nn.Sequential(
+ nn.Conv2d(
+ in_channels=filter[0], out_channels=filter[0], kernel_size=3, padding=1
+ ),
+ nn.Conv2d(in_channels=filter[0], out_channels=3, kernel_size=1, padding=0),
+ )
+
+ # define pooling and unpooling functions
+ self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
+ self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ # define convolutional block
+ def conv_layer(self, channel):
+ conv_block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=channel[0],
+ out_channels=channel[1],
+ kernel_size=3,
+ padding=1,
+ ),
+ nn.BatchNorm2d(num_features=channel[1]),
+ nn.ReLU(inplace=True),
+ )
+ return conv_block
+
+ def forward(self, x):
+ g_encoder, g_decoder, g_maxpool, g_upsampl, indices = (
+ [0] * 5 for _ in range(5)
+ )
+ for i in range(5):
+ g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))
+
+ # global shared encoder-decoder network
+ for i in range(5):
+ if i == 0:
+ g_encoder[i][0] = self.encoder_block[i](x)
+ g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
+ g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
+ else:
+ g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
+ g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
+ g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
+ feat = [g_maxpool[i]]
+ # feat = [g_maxpool]
+ for i in range(5):
+ if i == 0:
+ g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
+ g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
+ g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
+ else:
+ g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
+ g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
+ g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
+
+ feat.append(g_decoder[i][1])
+ # feat.append(g_decoder)
+
+ # define task prediction layers
+ t1_pred = self.pred_task1(g_decoder[i][1])
+ t2_pred = self.pred_task2(g_decoder[i][1])
+ t3_pred = self.pred_task3(g_decoder[i][1])
+ t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)
+
+ return [t1_pred, t2_pred, t3_pred], feat
+
+ def compute_losses(
+ self, x_pred1, x_pred2, x_pred3, x_output1, x_output2, x_output3
+ ):
+ # Compute supervised task-specific loss for all tasks when all task labels are available
+
+        # binary mask to mask out undefined pixel space
+ binary_mask = (
+ (torch.sum(x_output2, dim=1) != 0)
+ .type(torch.FloatTensor)
+ .unsqueeze(1)
+ .cuda()
+ )
+ binary_mask_3 = (
+ (torch.sum(x_output3, dim=1) != 0)
+ .type(torch.FloatTensor)
+ .unsqueeze(1)
+ .cuda()
+ )
+
+ # semantic loss: depth-wise cross entropy
+ loss1 = F.nll_loss(F.log_softmax(x_pred1, dim=1), x_output1, ignore_index=-1)
+
+ # depth loss: l1 norm
+ loss2 = torch.sum(torch.abs(x_pred2 - x_output2) * binary_mask) / torch.nonzero(
+ binary_mask
+ ).size(0)
+
+ # normal loss: dot product
+ loss3 = 1 - torch.sum((x_pred3 * x_output3) * binary_mask_3) / torch.nonzero(
+ binary_mask_3
+ ).size(0)
+
+ return [loss1, loss2, loss3]
+
+    # evaluation metrics from https://github.com/lorenmt/mtan
+ def compute_miou(self, x_pred, x_output):
+ _, x_pred_label = torch.max(x_pred, dim=1)
+ x_output_label = x_output
+ batch_size = x_pred.size(0)
+ for i in range(batch_size):
+ true_class = 0
+ first_switch = True
+ for j in range(self.num_classes):
+ pred_mask = torch.eq(
+ x_pred_label[i],
+ j * torch.ones(x_pred_label[i].shape).type(torch.LongTensor).cuda(),
+ )
+ true_mask = torch.eq(
+ x_output_label[i],
+ j
+ * torch.ones(x_output_label[i].shape).type(torch.LongTensor).cuda(),
+ )
+ mask_comb = pred_mask.type(torch.FloatTensor) + true_mask.type(
+ torch.FloatTensor
+ )
+ union = torch.sum((mask_comb > 0).type(torch.FloatTensor))
+ intsec = torch.sum((mask_comb > 1).type(torch.FloatTensor))
+ if union == 0:
+ continue
+ if first_switch:
+ class_prob = intsec / union
+ first_switch = False
+ else:
+ class_prob = intsec / union + class_prob
+ true_class += 1
+ if i == 0:
+ batch_avg = class_prob / true_class
+ else:
+ batch_avg = class_prob / true_class + batch_avg
+ return batch_avg / batch_size
+
+ def compute_iou(self, x_pred, x_output):
+ _, x_pred_label = torch.max(x_pred, dim=1)
+ x_output_label = x_output
+ batch_size = x_pred.size(0)
+ for i in range(batch_size):
+ if i == 0:
+ pixel_acc = torch.div(
+ torch.sum(
+ torch.eq(x_pred_label[i], x_output_label[i]).type(
+ torch.FloatTensor
+ )
+ ),
+ torch.sum((x_output_label[i] >= 0).type(torch.FloatTensor)),
+ )
+ else:
+ pixel_acc = pixel_acc + torch.div(
+ torch.sum(
+ torch.eq(x_pred_label[i], x_output_label[i]).type(
+ torch.FloatTensor
+ )
+ ),
+ torch.sum((x_output_label[i] >= 0).type(torch.FloatTensor)),
+ )
+ return pixel_acc / batch_size
+
+ def depth_error(self, x_pred, x_output):
+ binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).cuda()
+ x_pred_true = x_pred.masked_select(binary_mask)
+ x_output_true = x_output.masked_select(binary_mask)
+ abs_err = torch.abs(x_pred_true - x_output_true)
+ rel_err = torch.abs(x_pred_true - x_output_true) / x_output_true
+ return torch.sum(abs_err) / torch.nonzero(binary_mask).size(0), torch.sum(
+ rel_err
+ ) / torch.nonzero(binary_mask).size(0)
+
+ def normal_error(self, x_pred, x_output):
+ binary_mask = torch.sum(x_output, dim=1) != 0
+ error = (
+ torch.acos(
+ torch.clamp(
+ torch.sum(x_pred * x_output, 1).masked_select(binary_mask), -1, 1
+ )
+ )
+ .detach()
+ .cpu()
+ .numpy()
+ )
+ error = np.degrees(error)
+ return (
+ np.mean(error),
+ np.median(error),
+ np.mean(error < 11.25),
+ np.mean(error < 22.5),
+ np.mean(error < 30),
+ )
diff --git a/code/patches/mtpsl.patch b/code/patches/mtpsl.patch
new file mode 100644
index 0000000..f22858a
--- /dev/null
+++ b/code/patches/mtpsl.patch
@@ -0,0 +1,246 @@
+diff --git a/dataset/augmentation.py b/dataset/augmentation.py
+new file mode 100644
+index 0000000..2dfe210
+--- /dev/null
++++ b/dataset/augmentation.py
+@@ -0,0 +1,179 @@
++######
++# Augmentation on pytorch tensor
++######
++import matplotlib.pyplot as plt
++import numpy as np
++import torch
++from torch import Tensor
++import torchvision.transforms as T
++from torchvision.transforms import functional as F, InterpolationMode
++from typing import Dict, List, Optional, Tuple
++
++class ImgAugment(torch.nn.Module):
++ """
++ Reference: RandAugment
++ https://pytorch.org/vision/main/_modules/torchvision/transforms/autoaugment.html#RandAugment
++ """
++ def __init__(
++ self,
++ num_ops: int = 2,
++ magnitude: int = 6, # default 9, range from 0 ~ num_magnitude_bins
++ num_magnitude_bins: int = 31, # magnitude resolution
++ interpolation: InterpolationMode = InterpolationMode.NEAREST,
++ fill: Optional[List[float]] = None,
++ magnitude_sampling=True,
++ ) -> None:
++ super().__init__()
++ self.num_ops = num_ops
++ self.magnitude = magnitude
++ self.num_magnitude_bins = num_magnitude_bins
++ self.interpolation = interpolation
++ self.fill = fill
++ self.magnitude_sampling = magnitude_sampling
++
++ def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
++ return {
++ # op_name: (magnitudes, signed)
++ "Identity": (torch.tensor(0.0), False),
++ # "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
++ # "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
++ # "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
++ # "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
++ # "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
++ "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
++ "Color": (torch.linspace(0.0, 0.9, num_bins), True),
++ "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
++ "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
++ "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
++ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
++ "AutoContrast": (torch.tensor(0.0), False),
++ "Equalize": (torch.tensor(0.0), False),
++ }
++
++ def forward(self, img: Tensor) -> Tensor:
++ """
++ img (PIL Image or Tensor): Image to be transformed.
++
++ Returns:
++ PIL Image or Tensor: Transformed image.
++ """
++ fill = self.fill
++ # channels, height, width = F.get_dimensions(img)
++ channels, height, width = img.shape
++ if isinstance(img, Tensor):
++ if isinstance(fill, (int, float)):
++ fill = [float(fill)] * channels
++ elif fill is not None:
++ fill = [float(f) for f in fill]
++
++ op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
++ for _ in range(self.num_ops):
++ op_index = int(torch.randint(len(op_meta), (1,)).item())
++ op_name = list(op_meta.keys())[op_index]
++ magnitudes, signed = op_meta[op_name]
++ if self.magnitude_sampling:
++ sampled_magnitude = torch.randint(self.magnitude, (1,)).item()
++ else:
++ sampled_magnitude = self.magnitude
++ magnitude = float(magnitudes[sampled_magnitude].item()) if magnitudes.ndim > 0 else 0.0
++ if signed and torch.randint(2, (1,)):
++ magnitude *= -1.0
++ img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
++
++ return img
++
++
++
++def _apply_op(
++ img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
++):
++ if op_name == "ShearX":
++ # magnitude should be arctan(magnitude)
++ # official autoaug: (1, level, 0, 0, 1, 0)
++ # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
++ # compared to
++ # torchvision: (1, tan(level), 0, 0, 1, 0)
++ # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
++ img = F.affine(
++ img,
++ angle=0.0,
++ translate=[0, 0],
++ scale=1.0,
++ shear=[math.degrees(math.atan(magnitude)), 0.0],
++ interpolation=interpolation,
++ fill=fill,
++ center=[0, 0],
++ )
++ elif op_name == "ShearY":
++ # magnitude should be arctan(magnitude)
++ # See above
++ img = F.affine(
++ img,
++ angle=0.0,
++ translate=[0, 0],
++ scale=1.0,
++ shear=[0.0, math.degrees(math.atan(magnitude))],
++ interpolation=interpolation,
++ fill=fill,
++ center=[0, 0],
++ )
++ elif op_name == "TranslateX":
++ img = F.affine(
++ img,
++ angle=0.0,
++ translate=[int(magnitude), 0],
++ scale=1.0,
++ interpolation=interpolation,
++ shear=[0.0, 0.0],
++ fill=fill,
++ )
++ elif op_name == "TranslateY":
++ img = F.affine(
++ img,
++ angle=0.0,
++ translate=[0, int(magnitude)],
++ scale=1.0,
++ interpolation=interpolation,
++ shear=[0.0, 0.0],
++ fill=fill,
++ )
++ elif op_name == "Rotate":
++ img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
++ elif op_name == "Brightness":
++ img = F.adjust_brightness(img, 1.0 + magnitude)
++ elif op_name == "Color":
++ img = F.adjust_saturation(img, 1.0 + magnitude)
++ elif op_name == "Contrast":
++ img = F.adjust_contrast(img, 1.0 + magnitude)
++ elif op_name == "Sharpness":
++ img = F.adjust_sharpness(img, 1.0 + magnitude)
++ elif op_name == "Posterize":
++ if torch.is_tensor(img):
++ img = T.ToPILImage()(img)
++ img = F.posterize(img, int(magnitude))
++ img = T.ToTensor()(img)
++ else:
++ img = F.posterize(img, int(magnitude))
++ elif op_name == "Solarize":
++ if torch.is_tensor(img):
++ img = T.ToPILImage()(img)
++ img = F.solarize(img, magnitude)
++ img = T.ToTensor()(img)
++ else:
++ img = F.solarize(img, magnitude)
++ elif op_name == "AutoContrast":
++ img = F.autocontrast(img)
++ elif op_name == "Equalize":
++ if torch.is_tensor(img):
++ img = T.ToPILImage()(img)
++ img = F.equalize(img)
++ img = T.ToTensor()(img)
++ else:
++ img = F.equalize(img)
++ elif op_name == "Invert":
++ img = F.invert(img)
++ elif op_name == "Identity":
++ pass
++ else:
++ raise ValueError(f"The provided operator {op_name} is not recognized.")
++ return img
+diff --git a/dataset/nyuv2ssl.py b/dataset/nyuv2ssl.py
+index ecebd7f..916cc9f 100644
+--- a/dataset/nyuv2ssl.py
++++ b/dataset/nyuv2ssl.py
+@@ -9,7 +9,7 @@ import torchvision.transforms as transforms
+ from PIL import Image
+ import random
+ import torch.nn.functional as F
+-
++from dataset.augmentation import ImgAugment
+
+
+ class RandomScaleCrop(object):
+@@ -77,6 +77,7 @@ class NYUv2_crop(Dataset):
+ self.root = os.path.expanduser(root)
+ self.augmentation = augmentation
+ self.aug_twice = aug_twice
++ self.augmenter = ImgAugment(num_ops = 2, magnitude = 6, num_magnitude_bins=31, magnitude_sampling=True)
+
+
+ # R\read the data file
+@@ -102,6 +103,8 @@ class NYUv2_crop(Dataset):
+ elif self.augmentation and self.aug_twice:
+ image, semantic, depth, normal, _ = RandomScaleCrop()(image, semantic, depth, normal)
+ image1, semantic1, depth1, normal1, trans_params = RandomScaleCrop()(image, semantic, depth, normal)
++ image = self.augmenter(image)
++ image1 = self.augmenter(image1)
+ return image.type(torch.FloatTensor), semantic.type(torch.FloatTensor), depth.type(torch.FloatTensor), normal.type(torch.FloatTensor), index, image1.type(torch.FloatTensor), semantic1.type(torch.FloatTensor), depth1.type(torch.FloatTensor), normal1.type(torch.FloatTensor), trans_params
+ if self.train:
+ return image.type(torch.FloatTensor), semantic.type(torch.FloatTensor), depth.type(torch.FloatTensor), normal.type(torch.FloatTensor), index
+diff --git a/nyu_mtl_xtc.py b/nyu_mtl_xtc.py
+index e81799a..f9f1748 100644
+--- a/nyu_mtl_xtc.py
++++ b/nyu_mtl_xtc.py
+@@ -110,7 +110,7 @@ elif opt.ssl_type == 'randomlabels':
+ nyuv2_train_set = NYUv2_crop(root=dataset_path, train=True, augmentation=True, aug_twice=True)
+ nyuv2_test_set = NYUv2(root=dataset_path, train=False)
+
+-batch_size = 2
++batch_size = 4
+ nyuv2_train_loader = torch.utils.data.DataLoader(
+ dataset=nyuv2_train_set,
+ batch_size=batch_size,
+@@ -162,7 +162,7 @@ for epoch in range(start_epoch, total_epoch):
+ cost_normal = AverageMeter()
+ nyuv2_train_dataset = iter(nyuv2_train_loader)
+ for k in range(train_batch):
+- train_data, train_label, train_depth, train_normal, image_index, train_data1, train_label1, train_depth1, train_normal1, trans_params = nyuv2_train_dataset.next()
++ train_data, train_label, train_depth, train_normal, image_index, train_data1, train_label1, train_depth1, train_normal1, trans_params = next(nyuv2_train_dataset)
+ train_data, train_label = train_data.cuda(), train_label.type(torch.LongTensor).cuda()
+ train_depth, train_normal = train_depth.cuda(), train_normal.cuda()
+ train_data1, train_label1 = train_data1.cuda(), train_label1.type(torch.LongTensor).cuda()
+@@ -264,7 +264,7 @@ for epoch in range(start_epoch, total_epoch):
+ with torch.no_grad(): # operations inside don't track history
+ nyuv2_test_dataset = iter(nyuv2_test_loader)
+ for k in range(test_batch):
+- test_data, test_label, test_depth, test_normal = nyuv2_test_dataset.next()
++ test_data, test_label, test_depth, test_normal = next(nyuv2_test_dataset)
+ test_data, test_label = test_data.cuda(), test_label.type(torch.LongTensor).cuda()
+ test_depth, test_normal = test_depth.cuda(), test_normal.cuda()
+
diff --git a/code/train_nyuv2.py b/code/train_nyuv2.py
new file mode 100644
index 0000000..2e359d2
--- /dev/null
+++ b/code/train_nyuv2.py
@@ -0,0 +1,259 @@
+if __name__ == "__main__":
+ import os
+ import torch
+ import torch.optim as optim
+ from utils import *
+ from models.nyuv2_segnet_mtl import SegNet_MTL
+ from models.nyuv2_segnet_jtr import SegNet_JTR
+ from datasets.nyuv2 import NYUv2_MTL
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ parser = get_base_arg_parser("JTR NYUv2 onelabel/randomlabels", {"batch_size": 4})
+
+ parser.add_argument(
+ "--ssl-type",
+ required=True,
+ type=str,
+ help="type of ssl",
+ )
+ parser.add_argument(
+ "--seg-baseline", required=True, type=float, help="Seg. Baseline"
+ )
+ parser.add_argument(
+ "--depth-baseline", required=True, type=float, help="Depth Baseline"
+ )
+ parser.add_argument(
+ "--norm-baseline", required=True, type=float, help="Norm. Baseline"
+ )
+ args = parser.parse_args()
+ num_classes = 13
+ channel_dims = [num_classes, 1, 3]
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+
+ logger = get_logger_nyuv2()
+ mkdir_recursive(args.out_dir)
+ model = SegNet_MTL().cuda()
+ jtr = SegNet_JTR(channel_dims).cuda()
+
+ params = list(model.parameters()) + list(jtr.parameters())
+ optimizer = optim.Adam(params, lr=1e-4)
+ scheduler = optim.lr_scheduler.StepLR(
+ optimizer, step_size=args.step_size, gamma=0.5
+ )
+
+ label_weights = (
+ torch.load(os.path.join(args.label_dir, f"{args.ssl_type}.pth"))[
+ "labels_weights"
+ ]
+ .float()
+ .cuda()
+ )
+ nyuv2_train_set = NYUv2_MTL(
+ root=args.data_dir,
+ train=True,
+ )
+ nyuv2_train_loader = DataLoader(
+ nyuv2_train_set,
+ batch_size=args.batch_size,
+ shuffle=True,
+ num_workers=args.num_workers,
+ )
+ nyuv2_test_set = NYUv2_MTL(
+ root=args.data_dir,
+ train=False,
+ )
+ nyuv2_test_loader = DataLoader(
+ nyuv2_test_set,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.num_workers,
+ )
+
+ best_delta_mtl = float("-inf")
+ for epoch in range(args.num_epochs):
+ bar = tqdm(nyuv2_train_loader)
+ model.train()
+ jtr.train()
+ for batch_idx, (
+ train_idx,
+ train_input_0,
+ train_seg_0,
+ train_depth_0,
+ train_normal_0,
+ train_input_1,
+ train_seg_1,
+ train_depth_1,
+ train_normal_1,
+ train_input_2,
+ train_seg_2,
+ train_depth_2,
+ train_normal_2,
+ ) in enumerate(bar):
+ bar.set_description(f"Epoch {epoch}")
+ (
+ train_input_0,
+ train_seg_0,
+ train_depth_0,
+ train_normal_0,
+ train_input_1,
+ train_seg_1,
+ train_depth_1,
+ train_normal_1,
+ train_input_2,
+ train_seg_2,
+ train_depth_2,
+ train_normal_2,
+ ) = (
+ train_input_0.cuda(),
+ train_seg_0.type(torch.LongTensor).cuda(),
+ train_depth_0.cuda(),
+ train_normal_0.cuda(),
+ train_input_1.cuda(),
+ train_seg_1.type(torch.LongTensor).cuda(),
+ train_depth_1.cuda(),
+ train_normal_1.cuda(),
+ train_input_2.cuda(),
+ train_seg_2.type(torch.LongTensor).cuda(),
+ train_depth_2.cuda(),
+ train_normal_2.cuda(),
+ )
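+            # View 0 carries the supervised task losses; views 1 and 2 feed the
+            # JTR cross-view distance and reconstruction losses.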
+ train_data_cat = torch.cat(
+ (train_input_0, train_input_1, train_input_2), dim=0
+ )
+ train_pred_cat, _ = model(train_data_cat)
+ train_pred_0, train_pred_1, train_pred_2 = torch.split(
+ torch.cat(train_pred_cat, dim=1), train_idx.size(0), dim=0
+ )
+
+ loss = torch.tensor(0.0, requires_grad=True).cuda()
+
+ for ind in range(train_idx.size(0)):
+ w = label_weights[train_idx[ind].item()].clone().float().cuda()
+ labeled_tasks = torch.nonzero(w).squeeze(1).tolist()
+ train_pred_seg_0, train_pred_depth_0, train_pred_normal_0 = torch.split(
+ train_pred_0[ind].unsqueeze(0), channel_dims, dim=1
+ )
+ train_pred_seg_1, train_pred_depth_1, train_pred_normal_1 = torch.split(
+ train_pred_1[ind].unsqueeze(0), channel_dims, dim=1
+ )
+ train_pred_seg_2, train_pred_depth_2, train_pred_normal_2 = torch.split(
+ train_pred_2[ind].unsqueeze(0), channel_dims, dim=1
+ )
+ train_losses = model.compute_losses(
+ train_pred_seg_0,
+ train_pred_depth_0,
+ train_pred_normal_0,
+ train_seg_0[ind].unsqueeze(0),
+ train_depth_0[ind].unsqueeze(0),
+ train_normal_0[ind].unsqueeze(0),
+ )
+ L_sl = sum([train_losses[i] for i in labeled_tasks]) / train_idx.size(0)
+
+ L_dist, L_recon = jtr.compute_losses(
+ (
+ train_pred_seg_1,
+ train_pred_depth_1,
+ train_pred_normal_1,
+ train_seg_1[ind].unsqueeze(0),
+ train_depth_1[ind].unsqueeze(0),
+ train_normal_1[ind].unsqueeze(0),
+ ),
+ (
+ train_pred_seg_2,
+ train_pred_depth_2,
+ train_pred_normal_2,
+ train_seg_2[ind].unsqueeze(0),
+ train_depth_2[ind].unsqueeze(0),
+ train_normal_2[ind].unsqueeze(0),
+ ),
+ labeled_tasks,
+ num_classes,
+ args.recon_weight,
+ )
+
+ L_jtr = (
+ args.dist_mtl_weight * L_dist + args.recon_weight * L_recon
+ ) / train_idx.size(0)
+
+ loss += L_sl + L_jtr
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ model.eval()
+ conf_matrix = ConfMatrix(num_classes)
+ depth_meter = DepthMeter()
+ normal_meter = NormalsMeter()
+ with torch.no_grad():
+ for (
+ test_input,
+ test_seg,
+ test_depth,
+ test_normal,
+ ) in nyuv2_test_loader:
+ test_input, test_seg, test_depth, test_normal = (
+ test_input.cuda(),
+ test_seg.type(torch.LongTensor).cuda(),
+ test_depth.cuda(),
+ test_normal.cuda(),
+ )
+ test_pred, _ = model(test_input)
+ test_pred_seg, test_pred_depth, test_pred_normal = test_pred
+ conf_matrix.update(
+ test_pred_seg.argmax(1).flatten(), test_seg.flatten()
+ )
+ depth_meter.update(test_pred_depth, test_depth)
+ normal_meter.update(test_pred_normal, test_normal)
+ v_miou, v_acc = conf_matrix.get_metrics()
+ v_depth_scores = depth_meter.get_score()
+ v_l1, v_rel = v_depth_scores["l1"], v_depth_scores["rmse"]
+ v_normal_scores = normal_meter.get_score()
+ v_mean, v_med, v_11_25, v_22_5, v_30 = (
+ v_normal_scores["mean"],
+ v_normal_scores["rmse"],
+ v_normal_scores["11.25"],
+ v_normal_scores["22.5"],
+ v_normal_scores["30"],
+ )
+ delta_seg, delta_depth, delta_norm, delta_mtl = compute_delta_mtl_nyuv2(
+ v_miou,
+ v_l1,
+ v_mean,
+ args.seg_baseline,
+ args.depth_baseline,
+ args.norm_baseline,
+ )
+ is_best = delta_mtl > best_delta_mtl
+ best_delta_mtl = max(delta_mtl, best_delta_mtl)
+ info = [
+ v_miou,
+ v_acc,
+ v_l1,
+ v_rel,
+ v_mean,
+ v_med,
+ v_11_25,
+ v_22_5,
+ v_30,
+ delta_seg,
+ delta_depth,
+ delta_norm,
+ delta_mtl,
+ best_delta_mtl,
+ ]
+ write_logger(logger, epoch, info, args.out_dir, "logs.csv")
+ save_checkpoint(
+ args.out_dir,
+ {
+ "epoch": epoch,
+ "model_state_dict": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ },
+ is_best,
+ )
+ scheduler.step()
diff --git a/code/utils/__init__.py b/code/utils/__init__.py
new file mode 100644
index 0000000..4383e9c
--- /dev/null
+++ b/code/utils/__init__.py
@@ -0,0 +1,6 @@
+from .cli import *
+from .files import *
+from .evals import *
+from .gumbel import *
+from .process import *
+from .deltas import *
diff --git a/code/utils/cli.py b/code/utils/cli.py
new file mode 100644
index 0000000..426a5e7
--- /dev/null
+++ b/code/utils/cli.py
@@ -0,0 +1,27 @@
+import argparse
+from typing import Dict
+
+
+def get_base_arg_parser(description, arg_values: Dict) -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(description=description)
+
+ parser.add_argument("--data-dir", required=True, type=str, help="path to dataset")
+ parser.add_argument(
+ "--out-dir", required=True, type=str, help="path to output directory"
+ )
+ parser.add_argument(
+ "--label-dir", required=True, type=str, help="path to label dir"
+ )
+ parser.add_argument(
+ "--dist-mtl-weight", default=4, type=float, help="dist_mtl_weight"
+ )
+ parser.add_argument("--recon-weight", default=2, type=float, help="recon_weight")
+ parser.add_argument("--num-epochs", default=400, type=int, help="num_epochs")
+ parser.add_argument("--step-size", default=200, type=int, help="step_size")
+ parser.add_argument(
+ "--batch-size", default=arg_values["batch_size"], type=int, help="batch_size"
+ )
+ parser.add_argument("--num-workers", default=4, type=int, help="num_workers")
+ parser.add_argument("--seed", default=0, type=int, help="seed")
+
+ return parser
diff --git a/code/utils/deltas.py b/code/utils/deltas.py
new file mode 100644
index 0000000..678b1a1
--- /dev/null
+++ b/code/utils/deltas.py
@@ -0,0 +1,17 @@
+def compute_delta_mtl_nyuv2(
+ seg, depth, normal, seg_baseline, depth_baseline, normal_baseline
+):
+ seg_delta = (seg * 100 - seg_baseline) / seg_baseline
+ depth_delta = (depth_baseline - depth) / depth_baseline
+ normal_delta = (normal_baseline - normal) / normal_baseline
+ delta_mtl = (seg_delta + depth_delta + normal_delta) / 3
+
+ return seg_delta, depth_delta, normal_delta, delta_mtl
+
+
+def compute_delta_mtl_cityscapes(seg, depth, seg_baseline, depth_baseline):
+ seg_delta = (seg * 100 - seg_baseline) / seg_baseline
+ depth_delta = (depth_baseline - depth) / depth_baseline
+ delta_mtl = (seg_delta + depth_delta) / 2
+
+ return seg_delta, depth_delta, delta_mtl
diff --git a/code/utils/evals.py b/code/utils/evals.py
new file mode 100644
index 0000000..1b85512
--- /dev/null
+++ b/code/utils/evals.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from torch.autograd import Variable
+import torch.nn.init as init
+import numpy as np
+import pdb
+
+# Evaluation metrics adapted from https://github.com/lorenmt/mtan and from https://github.com/SimonVandenhende/Multi-Task-Learning-PyTorch
+
+
+class ConfMatrix(object):
+ def __init__(self, num_classes):
+ self.num_classes = num_classes
+ self.mat = None
+
+ def update(self, pred, target):
+ n = self.num_classes
+ if self.mat is None:
+ self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
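+        # Accumulate the confusion matrix with a flattened bincount over
+        # (n * target + pred) index pairs.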
+ with torch.no_grad():
+ k = (target >= 0) & (target < n)
+ inds = n * target[k].to(torch.int64) + pred[k]
+ self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
+
+ def get_metrics(self):
+ h = self.mat.float()
+ acc = torch.diag(h).sum() / h.sum()
+ iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
+ return torch.mean(iu).cpu().numpy(), acc.cpu().numpy()
+
+
+class NormalsMeter(object):
+ def __init__(self):
+ self.eval_dict = {
+ "mean": 0.0,
+ "rmse": 0.0,
+ "11.25": 0.0,
+ "22.5": 0.0,
+ "30": 0.0,
+ "n": 0,
+ }
+
+ @torch.no_grad()
+ def update(self, pred, gt):
+        # Performance is measured in a pixel-wise fashion (same as in the ASTMT code)
+ valid_mask = torch.sum(gt, dim=1) != 0
+ invalid_mask = torch.sum(gt, dim=1) == 0
+
+ # Calculate difference expressed in degrees
+ deg_diff_tmp = (180 / math.pi) * (
+ torch.acos(
+ torch.clamp(
+ torch.sum(pred * gt, 1).masked_select(valid_mask), min=-1, max=1
+ )
+ )
+ )
+
+ self.eval_dict["mean"] += torch.sum(deg_diff_tmp).item()
+ self.eval_dict["rmse"] += torch.sum(torch.pow(deg_diff_tmp, 2)).item()
+ self.eval_dict["11.25"] += (
+ torch.sum((deg_diff_tmp < 11.25).float()).item() * 100
+ )
+ self.eval_dict["22.5"] += torch.sum((deg_diff_tmp < 22.5).float()).item() * 100
+ self.eval_dict["30"] += torch.sum((deg_diff_tmp < 30).float()).item() * 100
+ self.eval_dict["n"] += deg_diff_tmp.numel()
+
+ def reset(self):
+ self.eval_dict = {
+ "mean": 0.0,
+ "rmse": 0.0,
+ "11.25": 0.0,
+ "22.5": 0.0,
+ "30": 0.0,
+ "n": 0,
+ }
+
+ def get_score(self, verbose=True):
+ eval_result = dict()
+ eval_result["mean"] = self.eval_dict["mean"] / self.eval_dict["n"]
+ eval_result["rmse"] = np.sqrt(self.eval_dict["rmse"] / self.eval_dict["n"])
+ eval_result["11.25"] = self.eval_dict["11.25"] / self.eval_dict["n"]
+ eval_result["22.5"] = self.eval_dict["22.5"] / self.eval_dict["n"]
+ eval_result["30"] = self.eval_dict["30"] / self.eval_dict["n"]
+
+ return eval_result
+
+
+class DepthMeter(object):
+ def __init__(self):
+ self.total_rmses = 0.0
+ self.total_l1 = 0.0
+ self.total_log_rmses = 0.0
+ self.n_valid = 0.0
+ self.num_images = 0.0
+ self.n_valid_image = []
+ self.rmses = []
+
+ @torch.no_grad()
+ def update(self, pred, gt):
+ pred, gt = pred.squeeze(), gt.squeeze()
+ self.num_images += pred.size(0)
+
+ # Determine valid mask
+ mask = (gt != 0).bool()
+        self.n_valid += mask.float().sum().item()  # total number of valid pixels
+
+ # Only positive depth values are possible
+ pred = torch.clamp(pred, min=1e-9)
+
+ # Per pixel rmse and log-rmse.
+ log_rmse_tmp = torch.pow(torch.log(gt) - torch.log(pred), 2)
+ log_rmse_tmp = torch.masked_select(log_rmse_tmp, mask)
+ self.total_log_rmses += log_rmse_tmp.sum().item()
+
+ pred = pred.masked_select(mask)
+ gt = gt.masked_select(mask)
+ rmse_tmp = (gt - pred).abs().pow(2).cpu()
+
+ l1_tmp = (gt - pred).abs()
+ self.total_rmses += rmse_tmp.sum().item()
+ self.total_l1 += l1_tmp.sum().item()
+
+ def reset(self):
+ self.rmses = []
+ self.log_rmses = []
+
+ def get_score(self, verbose=True):
+ eval_result = dict()
+ eval_result["rmse"] = np.sqrt(self.total_rmses / self.n_valid)
+ eval_result["l1"] = self.total_l1 / self.n_valid
+ eval_result["log_rmse"] = np.sqrt(self.total_log_rmses / self.n_valid)
+
+ return eval_result
diff --git a/code/utils/files.py b/code/utils/files.py
new file mode 100644
index 0000000..f2cd32e
--- /dev/null
+++ b/code/utils/files.py
@@ -0,0 +1,64 @@
+import os
+import torch
+import shutil
+from pandas import DataFrame
+
+
+def mkdir_recursive(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def save_checkpoint(base, state, is_best, filename="checkpoint.pth.tar"):
+ filename = os.path.join(base, filename)
+ torch.save(state, filename)
+ if is_best:
+ shutil.copyfile(filename, os.path.join(base, "model_best.pth.tar"))
+
+
+def get_logger_nyuv2():
+ df = DataFrame(
+ columns=[
+ "Epoch",
+ "V. Seg. mIoU",
+ "V. Seg. Acc.",
+ "V. Depth Abs.",
+ "V. Depth Rel.",
+ "V. Norm. Mean",
+ "V. Norm. Med.",
+ "V. Norm. 11.25",
+ "V. Norm. 22.5",
+ "V. Norm. 30",
+ "V. Delta Seg.",
+ "V. Delta Depth",
+ "V. Delta Norm.",
+ "V. Delta MTL",
+ "Best Delta MTL",
+ ]
+ )
+ return df
+
+
+def get_logger_cityscapes():
+ df = DataFrame(
+ columns=[
+ "Epoch",
+ "V. Seg. mIoU",
+ "V. Seg. Acc.",
+ "V. Depth Abs.",
+ "V. Depth Rel.",
+ "V. Delta Seg.",
+ "V. Delta Depth",
+ "V. Delta MTL",
+ "Best Delta MTL",
+ ]
+ )
+ return df
+
+
+def write_logger(logger, epoch, info, base, filename="logs.csv"):
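+ # Set or overwrite the row for this epoch, then rewrite the entire CSV on every call.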
+ logger.loc[epoch] = [
+ epoch,
+ *info,
+ ]
+ logger.to_csv(os.path.join(base, filename), index=False)
diff --git a/code/utils/gumbel.py b/code/utils/gumbel.py
new file mode 100644
index 0000000..42ac46a
--- /dev/null
+++ b/code/utils/gumbel.py
@@ -0,0 +1,60 @@
+import torch
+import torch.nn.functional as F
+
+
+def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
+ # type: (Tensor, float, bool, float, int) -> Tensor
+ r"""
+ Samples from the `Gumbel-Softmax distribution`_ and optionally discretizes.
+ You can use this function to replace "F.gumbel_softmax".
+
+ Args:
+ logits: `[..., num_features]` unnormalized log probabilities
+ tau: non-negative scalar temperature
+ hard: if ``True``, the returned samples will be discretized as one-hot vectors,
+ but will be differentiated as if it is the soft sample in autograd
+ dim (int): A dimension along which softmax will be computed. Default: -1.
+ Returns:
+ Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
+ If ``hard=True``, the returned samples will be one-hot, otherwise they will
+ be probability distributions that sum to 1 across `dim`.
+ .. note::
+ This function is here for legacy reasons, may be removed from nn.Functional in the future.
+ .. note::
+ The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft`
+ It achieves two things:
+ - makes the output value exactly one-hot
+ (since we add then subtract y_soft value)
+ - makes the gradient equal to y_soft gradient
+ (since we strip all other gradients)
+ Examples::
+ >>> logits = torch.randn(20, 32)
+ >>> # Sample soft categorical using reparametrization trick:
+ >>> F.gumbel_softmax(logits, tau=1, hard=False)
+ >>> # Sample hard categorical using "Straight-through" trick:
+ >>> F.gumbel_softmax(logits, tau=1, hard=True)
+ .. _Gumbel-Softmax distribution:
+ https://arxiv.org/abs/1611.00712
+ https://arxiv.org/abs/1611.01144
+ """
+
+ def _gen_gumbels():
+ gumbels = -torch.empty_like(logits).exponential_().log()
+ if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum():
+ # exponential_() can return exactly zero, so the log produces inf/nan; resample until the draw is finite
+ gumbels = _gen_gumbels()
+ return gumbels
+
+ gumbels = _gen_gumbels() # ~Gumbel(0,1)
+ gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
+ y_soft = gumbels.softmax(dim)
+
+ if hard:
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+ else:
+ # Reparametrization trick.
+ ret = y_soft
+ return ret
diff --git a/code/utils/process.py b/code/utils/process.py
new file mode 100644
index 0000000..bfc898e
--- /dev/null
+++ b/code/utils/process.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn.functional as F
+from .gumbel import gumbel_softmax
+
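+ # These constants are presumably dataset depth statistics (mean and standard deviation);
+ # _transform below squashes depth values into (0, 1) through the corresponding Normal CDF.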
+_avg = 2.3618
+_std = 1.5876
+
+d = torch.distributions.normal.Normal(_avg, _std)
+
+
+def _transform(x):
+ return d.cdf(x)
+
+
+def _gumbel(pred):
+ return gumbel_softmax(pred, dim=1, tau=1, hard=True)
+
+
+def preprocess_seg_pred(pred, mask, hard=False):
+ return (_gumbel(pred) if hard else F.softmax(pred, dim=1)) * mask
+
+
+def preprocess_depth_pred(pred, mask):
+ return _transform(pred * mask)
+
+
+def preprocess_normal_pred(pred, mask):
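+ # Shift masked normals from [-1, 1] into [0, 1]; pixels outside the mask map to 0.5.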
+ return (pred * mask + 1.0) / 2.0
+
+
+def get_mask_seg_gt(y):
+ return (y != -1).type(torch.LongTensor).cuda().unsqueeze(0).detach()
+
+
+def get_mask_depth_gt(y):
+ return (
+ (torch.sum(y, dim=1) != 0).type(torch.LongTensor).cuda().unsqueeze(1).detach()
+ )
+
+
+def get_mask_normal_gt(y):
+ return (
+ (torch.sum(y, dim=1) != 0).type(torch.LongTensor).cuda().unsqueeze(1).detach()
+ )
+
+
+def get_jtr_mask_intersection_nyuv2(seg_gt, depth_gt, normal_gt, labeled_tasks):
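+ # Intersect the per-task valid-pixel masks for whichever tasks are labeled (0: seg, 1: depth, 2: normals).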
+ mask = torch.ones_like(seg_gt.unsqueeze(0))
+ if 0 in labeled_tasks:
+ mask *= get_mask_seg_gt(seg_gt)
+ if 1 in labeled_tasks:
+ mask *= get_mask_depth_gt(depth_gt)
+ if 2 in labeled_tasks:
+ mask *= get_mask_normal_gt(normal_gt)
+ return mask.detach()
+
+
+def get_jtr_mask_intersection_cityscapes(seg_gt, depth_gt, labeled_tasks):
+ mask = torch.ones_like(seg_gt.unsqueeze(0))
+ if 0 in labeled_tasks:
+ mask *= get_mask_seg_gt(seg_gt)
+ if 1 in labeled_tasks:
+ mask *= get_mask_depth_gt(depth_gt)
+ return mask.detach()
+
+
+def preprocess_seg_gt(y, num_classes):
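+ # One-hot encode the label map; pixels with ignore index -1 end up as all-zero vectors.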
+ y = y.unsqueeze(0)
+ binary_mask = (y == -1).type(torch.FloatTensor).cuda()
+ y_1 = y.float() * (1 - binary_mask)
+ y_2 = torch.zeros(y.size(0), num_classes, y.size(2), y.size(3)).scatter_(
+ 1, y_1.type(torch.LongTensor), 1
+ ).cuda().detach() * (1 - binary_mask)
+ return y_2
+
+
+def preprocess_depth_gt(y):
+ binary_mask = get_mask_depth_gt(y)
+ return _transform(y * binary_mask)
+
+
+def preprocess_normal_gt(y):
+ return (y + 1.0) / 2.0
+
+
+def _get_detached_and_size(x):
+ return x.detach(), x.nonzero().size(0)
+
+
+def get_recon_loss_seg(x_pred, x_output, mask):
+ x_output, mask_size = _get_detached_and_size(x_output)
+ return (
+ F.binary_cross_entropy_with_logits(x_pred, x_output, reduction="none") * mask
+ ).sum() / mask_size
+
+
+def get_recon_loss_depth(x_pred, x_output, mask):
+ x_output, mask_size = _get_detached_and_size(x_output)
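+ # Squash logits into (0, 1), pulled slightly away from the exact endpoints.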
+ x_pred = (torch.sigmoid(x_pred) - 0.5) * (1 - 1e-5) + 0.5
+ x_output = _transform(x_output)
+ return (F.l1_loss(x_pred, x_output, reduction="none") * mask).sum() / mask_size
+
+
+def get_recon_loss_normal(x_pred, x_output, mask):
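+ # Cosine-style loss: 1 minus the masked mean inner product between the L2-normalized prediction and the target.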
+ x_output, mask_size = _get_detached_and_size(x_output)
+ return (
+ 1
+ - torch.sum(
+ ((x_pred / torch.norm(x_pred, p=2, dim=1, keepdim=True)) * x_output * mask)
+ )
+ / mask_size
+ )
diff --git a/website/.eslintignore b/website/.eslintignore
new file mode 100644
index 0000000..3897265
--- /dev/null
+++ b/website/.eslintignore
@@ -0,0 +1,13 @@
+.DS_Store
+node_modules
+/build
+/.svelte-kit
+/package
+.env
+.env.*
+!.env.example
+
+# Ignore files for PNPM, NPM and YARN
+pnpm-lock.yaml
+package-lock.json
+yarn.lock
diff --git a/website/.eslintrc.cjs b/website/.eslintrc.cjs
new file mode 100644
index 0000000..0b75758
--- /dev/null
+++ b/website/.eslintrc.cjs
@@ -0,0 +1,31 @@
+/** @type { import("eslint").Linter.Config } */
+module.exports = {
+ root: true,
+ extends: [
+ 'eslint:recommended',
+ 'plugin:@typescript-eslint/recommended',
+ 'plugin:svelte/recommended',
+ 'prettier'
+ ],
+ parser: '@typescript-eslint/parser',
+ plugins: ['@typescript-eslint'],
+ parserOptions: {
+ sourceType: 'module',
+ ecmaVersion: 2020,
+ extraFileExtensions: ['.svelte']
+ },
+ env: {
+ browser: true,
+ es2017: true,
+ node: true
+ },
+ overrides: [
+ {
+ files: ['*.svelte'],
+ parser: 'svelte-eslint-parser',
+ parserOptions: {
+ parser: '@typescript-eslint/parser'
+ }
+ }
+ ]
+};
diff --git a/website/.gitignore b/website/.gitignore
new file mode 100644
index 0000000..6635cf5
--- /dev/null
+++ b/website/.gitignore
@@ -0,0 +1,10 @@
+.DS_Store
+node_modules
+/build
+/.svelte-kit
+/package
+.env
+.env.*
+!.env.example
+vite.config.js.timestamp-*
+vite.config.ts.timestamp-*
diff --git a/website/.npmrc b/website/.npmrc
new file mode 100644
index 0000000..b6f27f1
--- /dev/null
+++ b/website/.npmrc
@@ -0,0 +1 @@
+engine-strict=true
diff --git a/website/.prettierignore b/website/.prettierignore
new file mode 100644
index 0000000..cc41cea
--- /dev/null
+++ b/website/.prettierignore
@@ -0,0 +1,4 @@
+# Ignore files for PNPM, NPM and YARN
+pnpm-lock.yaml
+package-lock.json
+yarn.lock
diff --git a/website/.prettierrc b/website/.prettierrc
new file mode 100644
index 0000000..9573023
--- /dev/null
+++ b/website/.prettierrc
@@ -0,0 +1,8 @@
+{
+ "useTabs": true,
+ "singleQuote": true,
+ "trailingComma": "none",
+ "printWidth": 100,
+ "plugins": ["prettier-plugin-svelte"],
+ "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
+}
diff --git a/website/README.md b/website/README.md
new file mode 100644
index 0000000..5ce6766
--- /dev/null
+++ b/website/README.md
@@ -0,0 +1,38 @@
+# create-svelte
+
+Everything you need to build a Svelte project, powered by [`create-svelte`](https://github.com/sveltejs/kit/tree/main/packages/create-svelte).
+
+## Creating a project
+
+If you're seeing this, you've probably already done this step. Congrats!
+
+```bash
+# create a new project in the current directory
+npm create svelte@latest
+
+# create a new project in my-app
+npm create svelte@latest my-app
+```
+
+## Developing
+
+Once you've created a project and installed dependencies with `npm install` (or `pnpm install` or `yarn`), start a development server:
+
+```bash
+npm run dev
+
+# or start the server and open the app in a new browser tab
+npm run dev -- --open
+```
+
+## Building
+
+To create a production version of your app:
+
+```bash
+npm run build
+```
+
+You can preview the production build with `npm run preview`.
+
+> To deploy your app, you may need to install an [adapter](https://kit.svelte.dev/docs/adapters) for your target environment.
diff --git a/website/package.json b/website/package.json
new file mode 100644
index 0000000..48313fb
--- /dev/null
+++ b/website/package.json
@@ -0,0 +1,37 @@
+{
+ "name": "website",
+ "version": "0.0.1",
+ "private": true,
+ "scripts": {
+ "dev": "vite dev",
+ "build": "vite build",
+ "preview": "vite preview",
+ "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
+ "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
+ "lint": "prettier --check . && eslint .",
+ "format": "prettier --write ."
+ },
+ "devDependencies": {
+ "@sveltejs/adapter-auto": "^3.0.0",
+ "@sveltejs/adapter-static": "^3.0.1",
+ "@sveltejs/kit": "^2.0.0",
+ "@sveltejs/vite-plugin-svelte": "^3.0.0",
+ "@types/eslint": "^8.56.0",
+ "@typescript-eslint/eslint-plugin": "^7.0.0",
+ "@typescript-eslint/parser": "^7.0.0",
+ "eslint": "^8.56.0",
+ "eslint-config-prettier": "^9.1.0",
+ "eslint-plugin-svelte": "^2.35.1",
+ "prettier": "^3.1.1",
+ "prettier-plugin-svelte": "^3.1.2",
+ "svelte": "^4.2.7",
+ "svelte-check": "^3.6.0",
+ "tslib": "^2.4.1",
+ "typescript": "^5.0.0",
+ "vite": "^5.0.3"
+ },
+ "type": "module",
+ "dependencies": {
+ "exio": "^0.6.45"
+ }
+}
diff --git a/website/src/app.d.ts b/website/src/app.d.ts
new file mode 100644
index 0000000..743f07b
--- /dev/null
+++ b/website/src/app.d.ts
@@ -0,0 +1,13 @@
+// See https://kit.svelte.dev/docs/types#app
+// for information about these interfaces
+declare global {
+ namespace App {
+ // interface Error {}
+ // interface Locals {}
+ // interface PageData {}
+ // interface PageState {}
+ // interface Platform {}
+ }
+}
+
+export {};
diff --git a/website/src/app.html b/website/src/app.html
new file mode 100644
index 0000000..14a3eb4
--- /dev/null
+++ b/website/src/app.html
@@ -0,0 +1,38 @@