This directory contains the source code for the project, excluding the main API scripts the end user should interact with.
- `core`: Package containing core library components such as `MonoDepthModule` or `DepthEvaluator`. Can depend on any custom package. Only API scripts should depend on it.
- `datasets`: Package containing PyTorch datasets.
- `devkits`: Package containing basic loading tools for datasets.
- `external_libs`: Package containing external libraries from other developers.
- `losses`: Package containing training losses.
- `networks`: Package containing network architectures (including contribution decoders).
- `regularizers`: Package containing regularizer losses.
- `tools`: Package containing more advanced utilities. They should depend only on each other and `utils`.
- `utils`: Package containing basic utilities. They should not have any custom dependencies!
- `__init__.py`: `src` package init.
- `paths.py`: File containing path management tools.
- `registry.py`: File containing the tools for registering models & datasets for training.
- `typing.py`: File containing custom type hints.
Please take into account the notes regarding dependencies when deciding where to incorporate custom code.
Contains the core library components required to train and evaluate a monocular depth estimation network.
- `evaluator`: Tools for computing predictions over a dataset and evaluating, such as `predict_depths` and `MonoDepthEvaluator`.
- `handlers`: Handlers that wrap multi-scale loss computation during training.
- `image_logger`: PyTorch Lightning callback for logging images after each epoch.
- `metrics`: Functions for computing the various sets of evaluation metrics, such as `eigen`, `benchmark`, `pointcloud` and `ibims`.
- `trainer`: Main PyTorch Lightning module for training, `MonoDepthModule`.
The `MonoDepthModule` used for training is implemented using PyTorch Lightning, which wraps the optimization procedure and provides hooks to various steps.
See the PyTorch Lightning docs for background info about how the code is organized and what hooks are available.
Overall, the module forward pass is split into:
- `forward`: Computes the network predictions.
- `forward_postprocess`: Prepares the predictions for loss computation. E.g. upsampling to common resolution & converting to depth.
- `forward_loss`: Computes the optimization loss and produces auxiliary outputs for logging.
- `compute_metrics`: Computes the metrics for logging and validation performance tracking.
- `log_dict`: Logs scalars every `n` steps.
- `image_logger`: Logs images at the end of each epoch.
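As a rough illustration of how these stages chain together, here is a minimal sketch in plain Python. The class `ToyModule`, the `step` helper, the method signatures and the use of float lists in place of tensors are all assumptions made for this example; only the stage names come from the list above.

```python
class ToyModule:
    """Hypothetical stand-in for the training module; floats replace tensors."""

    def forward(self, x: dict) -> dict:
        # Network predictions: a fake disparity value per input image.
        return {'disp': [0.5 for _ in x['imgs']]}

    def forward_postprocess(self, fwd: dict, x: dict, y: dict) -> dict:
        # Prepare predictions for the loss, here converting disparity to depth.
        fwd['depth'] = [1.0 / d for d in fwd['disp']]
        return fwd

    def forward_loss(self, fwd: dict, x: dict, y: dict) -> tuple:
        # Scalar loss plus auxiliary outputs kept for logging.
        err = [abs(p - t) for p, t in zip(fwd['depth'], y['depth'])]
        return sum(err) / len(err), {'abs_err': err}

    def compute_metrics(self, fwd: dict, y: dict) -> dict:
        # Metrics used for logging and validation tracking.
        rel = [abs(p - t) / t for p, t in zip(fwd['depth'], y['depth'])]
        return {'AbsRel': sum(rel) / len(rel)}

    def step(self, x: dict, y: dict) -> tuple:
        # Hypothetical helper that runs the stages in the listed order.
        fwd = self.forward(x)
        fwd = self.forward_postprocess(fwd, x, y)
        loss, loss_dict = self.forward_loss(fwd, x, y)
        metrics = self.compute_metrics(fwd, y)
        return loss, loss_dict, metrics


module = ToyModule()
loss, loss_dict, metrics = module.step({'imgs': [0, 1]}, {'depth': [2.0, 4.0]})
```

In the real module the scalar and image logging steps would follow, using `loss_dict` and `metrics` as inputs.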
To add a new network/loss to the training procedure:
- Implement it in the respective module.
- Add it to the `registry`.
- Add a new `if` block to the corresponding forward step based on the `registry` key.
- If adding a loss, add the corresponding wrapper to `handlers`.
- Add auxiliary inputs to `fwd` or `loss_dict` for logging.
- Add logging to `image_logger` based on the auxiliary inputs.
Configs consist of: networks, losses, datasets, loaders, optimizers, schedulers and trainer. For an example covering most of the available options, see this file.
- Networks and losses use dictionaries, where the keys correspond to the `registry` keys. Remaining parameters are `kwargs` to the respective class.
- Losses must add an additional parameter `weight`, which controls the scaling factor in the total loss.
- Datasets, optimizers and schedulers add a `type` argument corresponding to the `registry` keys.
- Datasets/loaders allow for different configs based on the `train` & `val` mode, overriding the original parameters.
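The mode-specific overrides can be pictured as a dictionary merge. The sketch below assumes the `train`/`val` sub-dictionaries are simply layered on top of the shared keys; `resolve_mode_cfg` is a hypothetical helper written for this example, not a function from the library.

```python
from copy import deepcopy


def resolve_mode_cfg(cfg: dict, mode: str, modes=('train', 'val')) -> dict:
    """Return a copy of `cfg` with the `mode` sub-dict merged over the shared keys."""
    cfg = deepcopy(cfg)  # Leave the caller's config untouched.
    overrides = {m: cfg.pop(m, {}) for m in modes}  # Strip every mode block.
    cfg.update(overrides[mode])  # Mode-specific values take priority.
    return cfg


dataset_cfg = {
    'type': 'kitti_lmdb',
    'size': [640, 192],
    'use_depth': True,
    'train': {'mode': 'train', 'use_aug': True},
    'val': {'mode': 'test', 'use_benchmark': True, 'use_aug': False},
}

train_cfg = resolve_mode_cfg(dataset_cfg, 'train')
val_cfg = resolve_mode_cfg(dataset_cfg, 'val')
```

Here `train_cfg` keeps the shared `type`, `size` and `use_depth` entries and gains `mode` and `use_aug` from the `train` block, while `val_cfg` additionally picks up `use_benchmark`.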

    # -----------------------------------------------------------------------------
    net:
      # Depth estimation network.
      depth:
        enc_name: 'convnext_base'  # Choose from `timm` encoders.
        pretrained: True
        dec_name: 'monodepth'  # Choose from custom decoders.

      # Pose estimation network (for use with purely monocular models).
      pose:
        enc_name: 'resnet18'  # Typically ResNet18 for efficiency.
        pretrained: True
    # -----------------------------------------------------------------------------
    loss:
      # Image-based reconstruction loss.
      img_recon:
        weight: 1
        loss_name: 'ssim'
    # -----------------------------------------------------------------------------
    dataset:
      type: 'kitti_lmdb'
      split: 'eigen_zhou'  # Can also use `eigen_benchmark`.
      size: [ 640, 192 ]  # Training images resolution.
      supp_idxs: [ -1, 1, 0 ]  # Support frames for reconstruction loss. Relative to the target frame. 0 for stereo.
      use_depth: True  # Needed to evaluate performance throughout training.

      train: {mode: 'train', use_aug: True}
      val: {mode: 'test', use_benchmark: True, use_aug: False}
    # -----------------------------------------------------------------------------
    loader:
      batch_size: 8
      num_workers: 8
      drop_last: True

      train: { shuffle: True }
      val: { shuffle: False }
    # -----------------------------------------------------------------------------
    optimizer:
      type: 'adam'  # Choose from any optimizer available from `timm`.
      lr: 0.0001
    # -----------------------------------------------------------------------------
    scheduler:
      type: 'steplr'
      step_size: 15
      gamma: 0.1
    # -----------------------------------------------------------------------------
    trainer:
      max_epochs: 30
      resume_training: True  # Will begin training from scratch if no checkpoints are found. Otherwise resume.
      monitor: 'AbsRel'  # Monitor metric to save `best` checkpoint.

      min_depth: 0.1  # Min depth to scale sigmoid disparity.
      max_depth: 100  # Max depth to scale sigmoid disparity.

      benchmark: True  # Pytorch cudnn benchmark.
    # -----------------------------------------------------------------------------

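The `min_depth` and `max_depth` options control how a sigmoid disparity in `[0, 1]` is mapped to a metric range. The sketch below assumes the convention popularized by Monodepth2 (linear scaling in disparity space, then inversion); the function name `disp_to_depth` is chosen for this example and the library's own implementation may differ.

```python
def disp_to_depth(disp: float, min_depth: float = 0.1, max_depth: float = 100.0) -> tuple:
    """Map a sigmoid disparity in [0, 1] to (scaled disparity, depth).

    Assumed convention: disparity 0 maps to `max_depth`, disparity 1 to `min_depth`.
    """
    min_disp = 1.0 / max_depth  # Smallest disparity corresponds to the farthest depth.
    max_disp = 1.0 / min_depth  # Largest disparity corresponds to the nearest depth.
    scaled = min_disp + (max_disp - min_disp) * disp
    return scaled, 1.0 / scaled


_, far = disp_to_depth(0.0)   # Close to `max_depth`.
_, near = disp_to_depth(1.0)  # Close to `min_depth`.
_, mid = disp_to_depth(0.5)
```

Because the scaling is linear in disparity rather than depth, a disparity of 0.5 lands much closer to `min_depth` than to `max_depth`.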
Contains PyTorch datasets required for training and/or evaluating.
- `base`: `BaseDataset` that all other datasets should inherit from, provides utilities for logging, loading and visualizing.
- `kitti_raw`: `KittiRawDataset` for loading Kitti Raw Sync.
- `kitti_raw_lmdb`: `KittiRawLmdbDataset` for loading Kitti Raw Sync (LMDB).
- `syns_patches`: `SYNSPatchesDataset` for loading SYNS-Patches.
All datasets should inherit from `BaseDataset` and implement/override the following methods.

    @abstractmethod
    def load(self, item: int, x: dict, y: dict, m: dict) -> BatchData:
        """Load data for a single 'item'. MUST return (x, y, m)."""

    def augment(self, x: dict, y: dict, m: dict) -> BatchData:
        """Augment a loaded item. Default is a no-op."""
        return x, y, m

    def transform(self, x: dict, y: dict, m: dict) -> BatchData:
        """Transform a loaded item. Default is a no-op."""
        return x, y, m

    def to_torch(self, x: dict, y: dict, m: dict) -> BatchData:
        """Convert (x, y, m) to torch Tensors. Default converts to torch and permutes >=3D tensors."""
        return ops.to_torch((x, y, m))

    @classmethod
    def collate_fn(cls, batch: Sequence[BatchData]):
        """Function to collate multiple dataset items. By default uses the PyTorch collator."""
        return default_collate(batch)

    def create_axs(self) -> Axes:
        """Create the axis structure required for plotting. Assumes data will be in numpy format."""
        _, ax = plt.subplots()
        return ax

    @abstractmethod
    def show(self, x: dict, y: dict, m: dict, axs: Optional[Axes] = None) -> None:
        """Show a single dataset item. Should call 'create_axs' if 'axs' is None."""

Datasets must return batches as three dictionaries:
- `x`: Contains data required for the network forward pass. E.g. images, indexes of support frames.
- `y`: Contains auxiliary data required for loss/metric computation. E.g. depth, edges, non-augmented images.
- `m`: Contains metadata about the loaded batch. E.g. loaded indexes, augmentations applied or errors while loading.
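To make the `(x, y, m)` convention concrete, the following sketch builds a toy dataset without any PyTorch dependency. `SketchDataset` is a hypothetical, heavily simplified stand-in for `BaseDataset`; the order in which `__getitem__` calls `load`, `augment` and `transform` is an assumption based on the method list, and the data is fake.

```python
from abc import ABC, abstractmethod


class SketchDataset(ABC):
    """Hypothetical, simplified stand-in for `BaseDataset`."""

    def __getitem__(self, item: int):
        # Start from empty dictionaries and let each stage fill or modify them.
        x, y, m = {}, {}, {'item': item, 'augs': []}
        x, y, m = self.load(item, x, y, m)
        x, y, m = self.augment(x, y, m)
        x, y, m = self.transform(x, y, m)
        return x, y, m

    @abstractmethod
    def load(self, item, x, y, m):
        """Fill the dictionaries for a single item."""

    def augment(self, x, y, m):
        return x, y, m  # No-op by default.

    def transform(self, x, y, m):
        return x, y, m  # No-op by default.


class ToyDataset(SketchDataset):
    def load(self, item, x, y, m):
        x['imgs'] = [item, item + 1, item + 2]  # Fake 1D "image".
        x['supp_idxs'] = [-1, 1]  # Indexes of support frames.
        y['imgs'] = list(x['imgs'])  # Non-augmented copy kept for the loss.
        y['depth'] = [float(item + 1)] * 3  # Fake ground-truth depth.
        return x, y, m

    def augment(self, x, y, m):
        x['imgs'] = [2 * v for v in x['imgs']]  # Photometric change on the input only.
        m['augs'].append('brightness')  # Record the augmentation in the metadata.
        return x, y, m


x, y, m = ToyDataset()[2]
```

The network input in `x` is augmented, `y` keeps the untouched image and depth for the loss, and `m` records which item was loaded and which augmentations were applied.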
Contains low-level tools for loading and interacting with the available datasets.
- `kitti_raw`: Tools for loading Kitti Raw Sync.
- `kitti_raw_lmdb`: Tools for loading Kitti Raw Sync (LMDBs).
- `syns_patches`: Tools for loading SYNS-Patches.
Contains libraries from other developers.
- Databases: Tools for creating LMDB datasets.
- Chamfer Distance: C++/CUDA implementation of the Chamfer distance.
The main available losses are:
- `ReconstructionLoss`: Base view synthesis loss. Additionally used for feature-based view synthesis and autoencoder image reconstruction.
- `RegressionLoss`: Proxy depth regression loss. Additionally used for virtual stereo consistency.
NOTE: Each of these incorporates multiple different contributions based on the available input configuration. Check out the respective documentation for additional details.
New losses should be added as per the instructions in the registry. Losses must return a tuple consisting of:
"""
:return (tuple) (
    loss: (Tensor) (,) Scalar loss value.
    loss_dict: (TensorDict) Dictionary containing intermediate loss outputs used for TensorBoard logging.
)
"""
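Since every loss returns the same `(loss, loss_dict)` pair, several losses can be combined uniformly. The sketch below shows one way the per-loss `weight` from the config could scale each term in the total; `total_loss`, `l1_loss` and `l2_loss` are hypothetical helpers, and plain floats and lists stand in for tensors.

```python
def l1_loss(pred, target):
    """Toy loss following the (loss, loss_dict) return convention."""
    err = [abs(p - t) for p, t in zip(pred, target)]
    return sum(err) / len(err), {'l1_error': err}


def l2_loss(pred, target):
    err = [(p - t) ** 2 for p, t in zip(pred, target)]
    return sum(err) / len(err), {'l2_error': err}


def total_loss(losses: dict, weights: dict, pred, target):
    """Accumulate `weight * loss` over all losses and merge their logging outputs."""
    total, logs = 0.0, {}
    for key, fn in losses.items():
        loss, loss_dict = fn(pred, target)
        total += weights[key] * loss  # The config `weight` scales each term.
        logs[key] = loss  # Keep the unweighted value for logging.
        logs.update(loss_dict)  # Intermediate outputs, e.g. error maps.
    return total, logs


total, logs = total_loss(
    losses={'l1': l1_loss, 'l2': l2_loss},
    weights={'l1': 1.0, 'l2': 0.5},
    pred=[1.0, 2.0],
    target=[1.0, 4.0],
)
```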
The main available networks are:
- `depth`: Predicts a dense disparity map from a single image.
- `pose`: Predicts the relative pose between two images in axis-angle format.
- `autoencoder`: Converts the input image into a compact feature representation, which can be used to reconstruct the image. Used primarily to learn a feature representation complementary to the image-based reconstruction loss.
These networks use any of the pretrained encoders available in `timm`.
New networks should be added as per the instructions in the registry.
Networks producing dense outputs (depth & autoencoder) additionally require a dense decoder:
- `cadepth`: Adds self-attention and channel-wise skip connections. From CA-Depth.
- `ddvnet`: Predicts depth as a discrete disparity volume. From Johnston.
- `diffnet`: Adds self-attention and channel-wise attention skip-connections. From DiffNet.
- `hrdepth`: Adds progressive skip connections & SqueezeExcitation. From HRDepth.
- `monodepth`: Default Conv+ELU+BilinearUpsample. From Monodepth.
- `superdepth`: Conv+ELU+PixelShuffle. From SuperDepth.
Custom decoders should be added to `DECODERS`.
Currently, all decoders are required to have roughly the same argument structure.
This could probably be improved by using additional `**kwargs` in the main network initializers.
"""
:param num_ch_enc: (Sequence[int]) List of channels per encoder stage.
:param enc_sc: (Sequence[int]) List of downsampling factor per encoder stage.
:param upsample_mode: (str) Torch upsampling mode. {'nearest', 'bilinear'...}
:param use_skip: (bool) If `True`, add skip connections from corresponding encoder stage.
:param out_sc: (Sequence[int]) List of multi-scale output downsampling factor as 2**s.
:param out_ch: (int) Number of output channels.
:param out_act: (str) Activation to apply to each output stage.
"""
Regularizers are meant to prevent suboptimal or degenerate representations, rather than driving the optimization. The main available regularizers are:
- `MaskReg`: Explainability mask regularization. From SfM-Learner.
- `OccReg`: Disparity occlusion regularization. From DVSO.
- `SmoothReg`: Disparity smoothness regularization. From multiple contributions.
- `FeatPeakReg`: First-order feature peakiness regularization. From FeatDepth.
- `FeatSmoothReg`: Second-order feature smoothness regularization. From FeatDepth.
New regularizers should be added as per the instructions in the registry. They must also follow the output format required by the losses.
A collection of more advanced utilities that depend only on each other or on `utils`.
- `geometry`: Depth scaling/conversion and view synthesis tools, such as `extract_edges`, `to_scaled`, `to_inv`, `T_from_AAt`, `ViewSynth`...
- `ops`: Collection of PyTorch operations, such as `to_torch`, `to_numpy`, `allow_np`, `interpolate_like`, `expand_dim`...
- `parsers`: Tools for instantiating classes from config dicts.
- `table_formatter`: `TableFormatter` to convert dataframes into LaTeX tables.
- `viz`: Visualization tools `rgb_from_disp` & `rgb_from_feat`.
A collection of basic utilities that do not depend on any other custom code from this library.
- `autoaugment`: `TrivialAugmentWide` from PyTorch, modified to only produce photometric augmentations.
- `callbacks`: Custom PyTorch Lightning callbacks, including progress bars and anomaly detection.
- `collate`: `default_collate` from PyTorch, modified to accept `MultiLevelTimer`.
- `deco`: Custom decorators, including `opt_args_deco`, `delegates`, `map_container` & `retry_new_on_error`.
- `io`: YAML loading/writing tools and image conversion.
- `metrics`: PyTorch Lightning metrics for use during training.
- `misc`: Collection of random utilities, `flatten_dict`, `sort_dict`, `get_logger` & `apply_cmap`.
- `timers`: `MultiLevelTimer` to allow for nested timing blocks.
Path management for datasets and storing/loading checkpoints is done based on predefined locations in `DATA_ROOTS` & `MODEL_ROOTS`.
This alleviates the need to provide long and repeated paths, remaining flexible to datasets being stored in different locations (e.g. local scratch spaces).
Instructions for setting up custom roots can be found in the main README.
This file additionally provides some utilities for finding dataset & model paths within the available roots: `find_data_dir` & `find_model_file`.
These functions will return the input path if it is an absolute path to an existing file/directory.
Otherwise, they will search the available roots and return the first existing path.
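The lookup rule described above can be sketched with `pathlib`. The helper `find_path` below is hypothetical and only mirrors the stated behaviour; raising `FileNotFoundError` when nothing matches is an assumption, as is the use of temporary directories as roots.

```python
import tempfile
from pathlib import Path
from typing import Sequence


def find_path(path, roots: Sequence) -> Path:
    """Return `path` if it is an existing absolute path, else the first match under `roots`."""
    path = Path(path)
    if path.is_absolute() and path.exists():
        return path

    for root in roots:
        candidate = Path(root) / path
        if candidate.exists():
            return candidate  # First root containing the path wins.

    # Assumed failure mode; the library may handle this differently.
    raise FileNotFoundError(f'Could not find "{path}" in any of the roots {list(roots)}')


with tempfile.TemporaryDirectory() as tmp:
    root_a, root_b = Path(tmp) / 'a', Path(tmp) / 'b'
    root_a.mkdir()
    (root_b / 'kitti').mkdir(parents=True)

    found = find_path('kitti', [root_a, root_b])  # Resolved relative to the roots.
    same = find_path(root_b / 'kitti', [root_a])  # Existing absolute path is returned as-is.
    expected = root_b / 'kitti'
```

Ordering the roots therefore sets their priority, e.g. placing a local scratch space before a slower shared drive.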
New networks, losses or datasets should be added to the registry via the `register` decorator.
This makes these classes accessible to the `parsers`, and in turn to the config files.

    import torch.nn as nn
    from src import register


    @register(name='awesome', type='loss')
    class MyAwesomeLoss(nn.Module):
        def forward(self, pred, target):
            err = (pred - target).abs().mean(dim=1, keepdim=True)
            loss = err.mean()
            return loss, {'l1_error': err}

- `type` selects the relevant registry, but can typically be omitted and guessed from the class name.
- `name` represents the identifier used in the configs and module forward pass. Multiple aliases can be registered by providing a `tuple`, useful when losses share the same underlying computations but require different inputs or preprocessing in `MonoDepthModule`. An example is the base `ReconstructionLoss`, which can be used with either images or dense feature maps.
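For readers unfamiliar with the pattern, a decorator-based registry with aliases and type guessing could look roughly like the sketch below. This is not the library's implementation: the `REGISTRIES` layout, the `guess_type` helper and its suffix rules are assumptions for illustration, and this toy `register` only shares its call signature with the real one.

```python
# Hypothetical registries, one per kind of component.
REGISTRIES = {'net': {}, 'loss': {}, 'data': {}}


def guess_type(cls) -> str:
    """Guess the registry from the class name suffix (assumed convention)."""
    cname = cls.__name__.lower()
    for kind, suffix in (('net', 'net'), ('loss', 'loss'), ('data', 'dataset')):
        if cname.endswith(suffix):
            return kind
    raise ValueError(f'Cannot guess registry for "{cls.__name__}", please provide `type`.')


def register(name, type=None):
    """Register a class under one name or a tuple of aliases."""
    names = (name,) if isinstance(name, str) else tuple(name)

    def wrapper(cls):
        kind = type or guess_type(cls)
        for n in names:
            if n in REGISTRIES[kind]:
                raise ValueError(f'"{n}" is already registered in "{kind}".')
            REGISTRIES[kind][n] = cls
        return cls  # Return the class unchanged so it stays usable directly.

    return wrapper


@register(name=('img_recon', 'feat_recon'))  # Two aliases, type guessed from the name.
class ToyReconstructionLoss:
    pass


@register(name='toy', type='net')  # Explicit type, since the name gives no hint.
class Anything:
    pass
```

Both aliases resolve to the same class, so a config can select `img_recon` or `feat_recon` while the training module decides which inputs to feed it.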
Contains custom type hints and details on the config formats. Some highlights include:
- `ArrDict`: Dict mapping from `str` to `np.ndarray`.
- `TensorDict`: Dict mapping from `str` to `torch.Tensor`.
- `BatchData`: Return type expected by dataset, consisting of `x` (network inputs), `y` (auxiliary data for losses) & `m` (batch metadata).
- `LossData`: Return type expected by losses, consisting of a scalar loss and a dictionary with intermediate tensors for logging.
- `MonoDepthCfg`: Config structure required by the `MonoDepthModule`.