diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/README.md b/README.md index 3c151f4..f1b3c1a 100644 --- a/README.md +++ b/README.md @@ -14,4 +14,67 @@ Thirdly, we present a new method for NCD based on online clustering that exploit Lastly, we introduce a new evaluation protocol to assess the performance of NCD for point cloud semantic segmentation. We thoroughly evaluate our method on SemanticKITTI and SemanticPOSS datasets, showing that it can significantly outperform the baseline. -Camera ready and code will be released soon! +:fire: For more information have a look at our [PAPER](https://arxiv.org/pdf/2303.11610)! :fire: + +Authors: + Luigi Riz, + [Cristiano Saltori](https://scholar.google.com/citations?user=PID7Z4oAAAAJ&hl), + [Elisa Ricci](https://scholar.google.ca/citations?user=xf1T870AAAAJ&hl), + [Fabio Poiesi](https://scholar.google.co.uk/citations?user=BQ7li6AAAAAJ&hl) + +## News :new: +- 3/2023: NOPS code is **OUT**!:fire: +- 3/2023: NOPS is accepted to CVPR 2023!:fire: Our work is the first allowing the segmentation of known and unknown classes in 3D Lidar scans! + +## Installation + +The code has been tested with Python 3.8, CUDA 11.3, pytorch 1.10.1 and pytorch-lighting 1.4.8. Any other version may require to update the code for compatibility. + +### Conda +To run the code, you need to install: +- [Pytorch 1.10.1](https://pytorch.org/get-started/previous-versions/) +- [Minkowski Engine](https://github.com/NVIDIA/MinkowskiEngine) +- [Pytorch-Lighting 1.4.8](https://www.pytorchlightning.ai) (be sure to install torchmetrics=0.7.2) +- [Scipy 1.7.3](https://scipy.org/install/) +- [Wandb](https://docs.wandb.ai/quickstart) + +## Data preparation +To download the data follow the instructions provided by [SemanticKITTI](http://www.semantic-kitti.org) and [SemanticPOSS](http://www.poss.pku.edu.cn/semanticposs.html). Then, use this structure of the folders: +``` +./ +├── +├── ... +└── path_to_data_shown_in_yaml_config/ + └── sequences + ├── 00/ + │ ├── velodyne/ + | | ├── 000000.bin + | | ├── 000001.bin + | | └── ... + │ └── labels/ + | ├── 000000.label + | ├── 000001.label + | └── ... + └── ... +``` + +## Commands +### Pretraining +To run the pretraining: +``` +python main_pretrain.py -s [SPLIT NUMBER] --dataset [SemanticPOSS, SemanticKITTI] +``` +For additional command line arguments, run: +``` +python main_pretrain.py -h +``` + +### Discovery +To run the discovery step (pretraining is not mandatory): +``` +python main_discover.py -s [SPLIT NUMBER] --dataset [SemanticPOSS, SemanticKITTI] +``` +For additional command line arguments, run: +``` +python main_discover.py -h +``` diff --git a/config/semkitti_dataset.yaml b/config/semkitti_dataset.yaml new file mode 100644 index 0000000..7394c2c --- /dev/null +++ b/config/semkitti_dataset.yaml @@ -0,0 +1,198 @@ +dataset_path: /data/disk1/share/luriz/datasets/SemanticKITTI/dataset/ + +folder_name: + input: velodyne + label: labels + +split_sequence: # sequence numbers + train: + - '00' + - '01' + - '02' + - '03' + - '04' + - '05' + - '06' + - '07' + - '09' + - '10' + valid: + - '08' + test: + - '11' + - '12' + - '13' + - '14' + - '15' + - '16' + - '17' + - '18' + - '19' + - '20' + - '21' + +learning_map: + 0 : -1 # "unlabeled" + 1 : -1 # "outlier" mapped to "unlabeled" --------------------------mapped + 10: 0 # "car" + 11: 1 # "bicycle" + 13: 4 # "bus" mapped to "other-vehicle" --------------------------mapped + 15: 2 # "motorcycle" + 16: 4 # "on-rails" mapped to "other-vehicle" ---------------------mapped + 18: 3 # "truck" + 20: 4 # "other-vehicle" + 30: 5 # "person" + 31: 6 # "bicyclist" + 32: 7 # "motorcyclist" + 40: 8 # "road" + 44: 9 # "parking" + 48: 10 # "sidewalk" + 49: 11 # "other-ground" + 50: 12 # "building" + 51: 13 # "fence" + 52: -1 # "other-structure" mapped to "unlabeled" ------------------mapped + 60: 8 # "lane-marking" to "road" ---------------------------------mapped + 70: 14 # "vegetation" + 71: 15 # "trunk" + 72: 16 # "terrain" + 80: 17 # "pole" + 81: 18 # "traffic-sign" + 99: -1 # "other-object" to "unlabeled" ----------------------------mapped + 252: 0 # "moving-car" to "car" ------------------------------------mapped + 253: 6 # "moving-bicyclist" to "bicyclist" ------------------------mapped + 254: 5 # "moving-person" to "person" ------------------------------mapped + 255: 7 # "moving-motorcyclist" to "motorcyclist" ------------------mapped + 256: 4 # "moving-on-rails" mapped to "other-vehicle" --------------mapped + 257: 4 # "moving-bus" mapped to "other-vehicle" -------------------mapped + 258: 3 # "moving-truck" to "truck" --------------------------------mapped + 259: 4 # "moving-other"-vehicle to "other-vehicle" ----------------mapped + +learning_map_inv: # inverse of previous map + -1: 0 # "unlabeled", and others ignored + 0: 10 # "car" + 1: 11 # "bicycle" + 2: 15 # "motorcycle" + 3: 18 # "truck" + 4: 20 # "other-vehicle" + 5: 30 # "person" + 6: 31 # "bicyclist" + 7: 32 # "motorcyclist" + 8: 40 # "road" + 9: 44 # "parking" + 10: 48 # "sidewalk" + 11: 49 # "other-ground" + 12: 50 # "building" + 13: 51 # "fence" + 14: 70 # "vegetation" + 15: 71 # "trunk" + 16: 72 # "terrain" + 17: 80 # "pole" + 18: 81 # "traffic-sign" + +color_map: # bgr + 0 : [0, 0, 0] + 1 : [0, 0, 255] + 10: [245, 150, 100] + 11: [245, 230, 100] + 13: [250, 80, 100] + 15: [150, 60, 30] + 16: [255, 0, 0] + 18: [180, 30, 80] + 20: [255, 0, 0] + 30: [30, 30, 255] + 31: [200, 40, 255] + 32: [90, 30, 150] + 40: [255, 0, 255] + 44: [255, 150, 255] + 48: [75, 0, 75] + 49: [75, 0, 175] + 50: [0, 200, 255] + 51: [50, 120, 255] + 52: [0, 150, 255] + 60: [170, 255, 150] + 70: [0, 175, 0] + 71: [0, 60, 135] + 72: [80, 240, 150] + 80: [150, 240, 255] + 81: [0, 0, 255] + 99: [255, 255, 50] + 252: [245, 150, 100] + 256: [255, 0, 0] + 253: [200, 40, 255] + 254: [30, 30, 255] + 255: [90, 30, 150] + 257: [250, 80, 100] + 258: [180, 30, 80] + 259: [255, 0, 0] + +labels: + 0 : "unlabeled" + 1 : "outlier" + 10: "car" + 11: "bicycle" + 13: "bus" + 15: "motorcycle" + 16: "on-rails" + 18: "truck" + 20: "other-vehicle" + 30: "person" + 31: "bicyclist" + 32: "motorcyclist" + 40: "road" + 44: "parking" + 48: "sidewalk" + 49: "other-ground" + 50: "building" + 51: "fence" + 52: "other-structure" + 60: "lane-marking" + 70: "vegetation" + 71: "trunk" + 72: "terrain" + 80: "pole" + 81: "traffic-sign" + 99: "other-object" + 252: "moving-car" + 253: "moving-bicyclist" + 254: "moving-person" + 255: "moving-motorcyclist" + 256: "moving-on-rails" + 257: "moving-bus" + 258: "moving-truck" + 259: "moving-other-vehicle" + +content: # as a ratio with the total number of points + 0: 0.018889854628292943 + 1: 0.0002937197336781505 + 10: 0.040818519255974316 + 11: 0.00016609538710764618 + 13: 2.7879693665067774e-05 + 15: 0.00039838616015114444 + 16: 0.0 + 18: 0.0020633612104619787 + 20: 0.0016218197275284021 + 30: 0.00017698551338515307 + 31: 1.1065903904919655e-08 + 32: 5.532951952459828e-09 + 40: 0.1987493871255525 + 44: 0.014717169549888214 + 48: 0.14392298360372 + 49: 0.0039048553037472045 + 50: 0.1326861944777486 + 51: 0.0723592229456223 + 52: 0.002395131480328884 + 60: 4.7084144280367186e-05 + 70: 0.26681502148037506 + 71: 0.006035012012626033 + 72: 0.07814222006271769 + 80: 0.002855498193863172 + 81: 0.0006155958086189918 + 99: 0.009923127583046915 + 252: 0.001789309418528068 + 253: 0.00012709999297008662 + 254: 0.00016059776092534436 + 255: 3.745553104802113e-05 + 256: 0.0 + 257: 0.00011351574470342043 + 258: 0.00010157861367183268 + 259: 4.3840131989471124e-05 \ No newline at end of file diff --git a/config/semposs_dataset.yaml b/config/semposs_dataset.yaml new file mode 100644 index 0000000..009a89f --- /dev/null +++ b/config/semposs_dataset.yaml @@ -0,0 +1,125 @@ +dataset_path: /data/disk1/share/luriz/datasets/SemanticPOSS/dataset/ + +folder_name: + input: velodyne + label: labels + +split_sequence: # sequence numbers + train: + - '00' + - '01' + - '02' + - '04' + - '05' + valid: + - '03' + +learning_map: + 0: -1 # unlabeled + 1: -1 # unlabeled -------------------------------mapped + 2: -1 # unlabeled -------------------------------mapped + 3: -1 # unlabeled -------------------------------mapped + 4: 0 # 1 person + 5: 0 # 2+ person --------------------------------mapped + 6: 1 # rider + 7: 2 # car + 8: 3 # trunk + 9: 4 # plants + 10: 5 # traffic sign 1 + 11: 5 # traffic sign 2 --------------------------mapped + 12: 5 # traffic sign 3 --------------------------mapped + 13: 6 # pole + 14: 7 # trashcan + 15: 8 # building + 16: 9 # cone/stone + 17: 10 # fence + 18: -1 # unlabeled -------------------------------mapped + 19: -1 # unlabeled -------------------------------mapped + 20: -1 # unlabeled -------------------------------mapped + 21: 11 # bike + 22: 12 # other-ground + +learning_map_inv: # inverse of previous map + -1: 0 # "unlabeled", and others ignored + 0: 4 # "person" + 1: 6 # "rider" + 2: 7 # "car" + 3: 8 # "trunk" + 4: 9 # "plants" + 5: 10 # "traffic-sign" + 6: 13 # "pole" + 7: 14 # "trashcan" + 8: 15 # "building" + 9: 16 # "cone/stone" + 10: 17 # "fence" + 11: 21 # "bike" + 12: 22 # "other-ground" + +color_map: # bgr + 0: [128, 128, 128] # unlabeled + 1: [0, 0, 0] # unlabeled + 2: [0, 0, 0] # unlabeled + 3: [0, 0, 0] # unlabeled + 4: [255, 30, 30] # 1 person + 5: [255, 30, 30] # 2+ person + 6: [255, 40, 200] # rider + 7: [100, 150, 245] # car + 8: [135,60,0] # trunk + 9: [0, 175, 0] # plants + 10: [255, 0, 0] # traffic sign 1 # standing sign + 11: [255, 0, 0] # traffic sign 2 # hanging sign + 12: [255, 0, 0] # traffic sign 3 # high/big hanging sign + 13: [255, 240, 150] # pole + 14: [125, 255, 0] # trashcan + 15: [255, 200, 0] # building + 16: [50, 255, 255] # cone/stone + 17: [255, 120, 50] # fence + 18: [0, 0, 0] # unlabeled + 19: [0, 0, 0] # unlabeled + 20: [0, 0, 0] # unlabeled + 21: [100, 230, 245] # bike + 22: [0, 0, 0] # other-ground + +labels: + 0: "unlabeled" + 4: "person" + 5: "2+ person" + 6: "rider" + 7: "car" + 8: "trunk" + 9: "plants" + 10: "traffic sign" # standing sign + 11: "traffic sign 2" # hanging sign + 12: "traffic sign 3" # high/big hanging sign + 13: "pole" + 14: "trashcan" + 15: "building" + 16: "cone-stone" + 17: "fence" + 21: "bike" + 22: "other-ground" + +content: # as a ratio with the total number of points + 0: 2.14244059e-02 + 1: 2.59110680e-08 + 2: 0.0 + 3: 0.0 + 4: 1.45552885e-02 + 5: 2.97170930e-03 + 6: 4.27795878e-03 + 7: 7.47442017e-02 + 8: 1.19028088e-02 + 9: 3.64772113e-01 + 10: 2.93810782e-03 + 11: 1.22949054e-03 + 12: 1.45168831e-03 + 13: 4.73977693e-03 + 14: 8.20411782e-04 + 15: 2.22458412e-01 + 16: 9.10416468e-04 + 17: 1.53973464e-02 + 18: 7.59448221e-04 + 19: 3.80011723e-05 + 20: 3.34610350e-04 + 21: 5.48716718e-02 + 22: 1.99402106e-01 \ No newline at end of file diff --git a/main_discover.py b/main_discover.py new file mode 100644 index 0000000..3b9dc5b --- /dev/null +++ b/main_discover.py @@ -0,0 +1,132 @@ +import os +from argparse import ArgumentParser +from datetime import datetime + +import numpy as np +import pytorch_lightning as pl +import torch +import yaml +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import CSVLogger, WandbLogger + +from modules.Discoverer import Discoverer +from utils import unkn_labels as unk_labels +from utils.callbacks import mIoUEvaluatorCallback + +SEED = 1234 + +parser = ArgumentParser() +parser.add_argument("-s", "--split", type=int, help="split", required=True) +parser.add_argument("--dataset", choices=["SemanticKITTI", "SemanticPOSS"], default="SemanticPOSS", type=str, help="dataset") +parser.add_argument("--dataset_config", default=None, type=str, help="dataset config file") +parser.add_argument("--voxel_size", default="0.05", type=float, help="voxel_size") +parser.add_argument("--downsampling", default="60000", type=int, help="number of points per pcd") +parser.add_argument("--batch_size", default=4, type=int, help="batch size") +parser.add_argument("--num_workers", default=8, type=int, help="number of workers") +parser.add_argument("--hungarian_at_each_step", default=True, action="store_true", help="enable hungarian pass at the end of each epoch") +parser.add_argument("--log_dir", default="logs", type=str, help="log directory") +parser.add_argument("--checkpoint_dir", default="checkpoints", type=str, help="checkpoint dir") +parser.add_argument("--use_uncertainty_queue", default=False, action="store_true", help="use uncertainty modeling for the queue") +parser.add_argument("--use_uncertainty_loss", default=False, action="store_true", help="use uncertainty modeling for the point evaluated during the loss") +parser.add_argument("--uncertainty_percentile", default=0.5, type=float, help="percentile used in uncertainty modeling") +parser.add_argument("--train_lr", default=1.0e-2, type=float, help="learning rate for newly initialized parts of the pipeline") +parser.add_argument("--finetune_lr", default=1.0e-4, type=float, help="learning rate for already initialized parts of the pipeline") +parser.add_argument("--use_scheduler", default=False, action="store_true", help="use lr scheduler (linear warm-up + cosine_annealing") +parser.add_argument("--warmup_epochs", default=4, type=int, help="warmup epochs") +parser.add_argument("--min_lr", default=1e-5, type=float, help="min learning rate") +parser.add_argument("--momentum_for_optim", default=0.9, type=float, help="momentum for optimizer") +parser.add_argument("--weight_decay_for_optim", default=1.0e-4, type=float, help="weight decay") +parser.add_argument("--overcluster_factor", default=None, type=int, help="overclustering factor") +parser.add_argument("--num_heads", default=1, type=int, help="number of heads for clustering") +parser.add_argument("--clear_cache_int", default=1, type=int, help="frequency of clear_cache") +parser.add_argument("--num_iters_sk", default=3, type=int, help="number of iters for Sinkhorn") +parser.add_argument("--initial_epsilon_sk", default=0.3, type=float, help="initial epsilon for the Sinkhorn") +parser.add_argument("--final_epsilon_sk", default=0.05, type=float, help="final epsilon for the Sinkhorn") +parser.add_argument("--adapting_epsilon_sk", default=False, action="store_true", help="use a decreasing value of epsilon for Sinkhorn") +parser.add_argument("--queue_start_epoch", default=2, type=int, help="the epoch in which to start to use the queue. -1 to never use the queue") +parser.add_argument("--queue_batches", default=10, type=int, help="umber of batches in the queue") +parser.add_argument("--queue_percentage", default=0.1, type=float, help="percentage of novel points per batch retained in the queue") +parser.add_argument("--comment", default=datetime.now().strftime("%b%d_%H-%M-%S"), type=str) +parser.add_argument("--project", default="NOPS", type=str, help="wandb project") +parser.add_argument("--entity", default="luigiriz", type=str, help="wandb entity") +parser.add_argument("--offline", default=False, action="store_true", help="disable wandb") +parser.add_argument("--pretrained", type=str, help="pretrained checkpoint path") +parser.add_argument("--epochs", type=int, default=10, help="training epochs") +parser.add_argument("--set_deterministic", default=False, action="store_true") + + +def main(args): + + if args.set_deterministic: + os.environ["PYTHONHASHSEED"] = str(SEED) + np.random.seed(SEED) + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + torch.backends.cudnn.benchmark = True + + if not os.path.exists(args.checkpoint_dir): + os.mkdir(args.checkpoint_dir) + + if not os.path.exists(args.log_dir): + os.mkdir(args.log_dir) + + run_name = "-".join([f"S{args.split}", "discover", args.dataset, args.comment]) + wandb_logger = WandbLogger( + save_dir=args.log_dir, + name=run_name, + project=args.project, + entity=args.entity, + offline=args.offline, + ) + + if args.dataset_config is None: + if args.dataset == "SemanticKITTI": + args.dataset_config = "config/semkitti_dataset.yaml" + elif args.dataset == "SemanticPOSS": + args.dataset_config = "config/semposs_dataset.yaml" + else: + raise NameError(f"Dataset {args.dataset} not implemented") + + with open(args.dataset_config, "r") as f: + dataset_config = yaml.safe_load(f) + + unknown_labels = unk_labels.unknown_labels( + split=args.split, dataset_config=dataset_config + ) + + number_of_unk = len(unknown_labels) + + label_mapping, label_mapping_inv, unknown_label = unk_labels.label_mapping( + unknown_labels, dataset_config["learning_map_inv"].keys() + ) + + args.num_classes = len(label_mapping) + args.num_unlabeled_classes = number_of_unk + args.num_labeled_classes = args.num_classes - args.num_unlabeled_classes + + mIoU_callback = mIoUEvaluatorCallback() + checkpoint_callback = ModelCheckpoint( + save_top_k=-1, + save_weights_only=True, + dirpath=args.checkpoint_dir, + every_n_epochs=True, + ) + csv_logger = CSVLogger(save_dir=args.log_dir) + + loggers = [wandb_logger, csv_logger] if wandb_logger is not None else [csv_logger] + + model = Discoverer(label_mapping, label_mapping_inv, unknown_label, **args.__dict__) + trainer = pl.Trainer( + max_epochs=args.epochs, + logger=loggers, + gpus=-1, + num_sanity_val_steps=0, + callbacks=[mIoU_callback, checkpoint_callback], + ) + trainer.fit(model) + + +if __name__ == "__main__": + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/main_pretrain.py b/main_pretrain.py new file mode 100644 index 0000000..64a249f --- /dev/null +++ b/main_pretrain.py @@ -0,0 +1,104 @@ +import os +from argparse import ArgumentParser +from datetime import datetime + +import numpy as np +import pytorch_lightning as pl +import torch +import yaml +from pytorch_lightning.loggers import CSVLogger, WandbLogger + +from modules.Pretrainer import Pretrainer +from utils import unkn_labels as unk_labels +from utils.callbacks import PretrainCheckpointCallback, mIoUEvaluatorCallback + +SEED = 1234 + +parser = ArgumentParser() +parser.add_argument("-s", "--split", type=int, help="split", required=True) +parser.add_argument("--dataset", choices=['SemanticKITTI', 'SemanticPOSS'], default="SemanticPOSS", type=str, help="dataset") +parser.add_argument("--dataset_config", default=None, type=str, help="dataset config file") +parser.add_argument("--voxel_size", default="0.05", type=float, help="voxel_size") +parser.add_argument("--downsampling", default="60000", type=int, help="number of points per pcd") +parser.add_argument("--batch_size", default=8, type=int, help="batch size") +parser.add_argument("--num_workers", default=8, type=int, help="number of workers") +parser.add_argument("--log_dir", default="logs", type=str, help="log directory") +parser.add_argument("--checkpoint_dir", default="checkpoints_pretraining", type=str, help="checkpoint dir") +parser.add_argument("--train_lr", default=1.0e-2, type=float, help="learning rate for newly initialized parts of the pipeline") +parser.add_argument("--momentum_for_optim", default=0.9, type=float, help="momentum for optimizer") +parser.add_argument("--weight_decay_for_optim", default=1.0e-4, type=float, help="weight decay") +parser.add_argument("--clear_cache_int", default=1, type=int, help="frequency of clear_cache") +parser.add_argument("--comment", default=datetime.now().strftime("%b%d_%H-%M-%S"), type=str) +parser.add_argument("--project", default="NOPS", type=str, help="wandb project") +parser.add_argument("--entity", default="luigiriz", type=str, help="wandb entity") +parser.add_argument("--offline", default=False, action="store_true", help="disable wandb") +parser.add_argument("--epochs", type=int, default=20, help="training epochs") +parser.add_argument("--set_deterministic", default=False, action="store_true") + +def main(args): + + if args.set_deterministic: + os.environ["PYTHONHASHSEED"] = str(SEED) + np.random.seed(SEED) + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + torch.backends.cudnn.benchmark = True + + if not os.path.exists(args.checkpoint_dir): + os.mkdir(args.checkpoint_dir) + + if not os.path.exists(args.log_dir): + os.mkdir(args.log_dir) + + run_name = "-".join([f'S{args.split}', "pretrain", args.dataset, args.comment]) + wandb_logger = WandbLogger( + save_dir=args.log_dir, + name=run_name, + project=args.project, + entity=args.entity, + offline=args.offline, + ) + + if args.dataset_config is None: + if args.dataset == 'SemanticKITTI': + args.dataset_config = 'config/semkitti_dataset.yaml' + elif args.dataset == 'SemanticPOSS': + args.dataset_config = 'config/semposs_dataset.yaml' + else: + raise NameError(f'Dataset {args.dataset} not implemented') + + with open(args.dataset_config, 'r') as f: + dataset_config = yaml.safe_load(f) + + unknown_labels = unk_labels.unknown_labels( + split=args.split, dataset_config=dataset_config) + + number_of_unk=len(unknown_labels) + + label_mapping, label_mapping_inv, unknown_label = unk_labels.label_mapping( + unknown_labels, dataset_config['learning_map_inv'].keys()) + + args.num_classes = len(label_mapping) + args.num_unlabeled_classes = number_of_unk + args.num_labeled_classes = args.num_classes - args.num_unlabeled_classes + + mIoU_callback = mIoUEvaluatorCallback() + pretrain_checkpoint_callback = PretrainCheckpointCallback() + csv_logger = CSVLogger(save_dir=args.log_dir) + + model = Pretrainer(label_mapping, label_mapping_inv, unknown_label, **args.__dict__) + loggers = [wandb_logger, csv_logger] if wandb_logger is not None else [csv_logger] + + trainer = pl.Trainer( + max_epochs=args.epochs, + logger=loggers, + gpus=-1, + callbacks=[mIoU_callback, pretrain_checkpoint_callback] + ) + trainer.fit(model) + + +if __name__ == "__main__": + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/minkunet.py b/models/minkunet.py new file mode 100644 index 0000000..e8891a6 --- /dev/null +++ b/models/minkunet.py @@ -0,0 +1,206 @@ +# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural +# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part +# of the code. +import MinkowskiEngine as ME + +from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck + +from models.resnet import ResNetBase + + +class MinkUNetBase(ResNetBase): + BLOCK = None + PLANES = None + DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) + INIT_DIM = 32 + OUT_TENSOR_STRIDE = 1 + + # To use the model, must call initialize_coords before forward pass. + # Once data is processed, call clear to reset the model before calling + # initialize_coords + def __init__(self, in_channels, out_channels, D=3): + ResNetBase.__init__(self, in_channels, out_channels, D) + + def network_initialization(self, in_channels, out_channels, D): + # Output of the first conv concated to conv6 + self.inplanes = self.INIT_DIM + self.conv0p1s1 = ME.MinkowskiConvolution( + in_channels, self.inplanes, kernel_size=5, dimension=D) + + self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) + + self.conv1p1s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) + + self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], + self.LAYERS[0]) + + self.conv2p2s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) + + self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], + self.LAYERS[1]) + + self.conv3p4s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + + self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) + self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], + self.LAYERS[2]) + + self.conv4p8s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) + self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], + self.LAYERS[3]) + + self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D) + self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) + + self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion + self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], + self.LAYERS[4]) + self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D) + self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) + + self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion + self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], + self.LAYERS[5]) + self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D) + self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) + + self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion + self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], + self.LAYERS[6]) + self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D) + self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) + + self.inplanes = self.PLANES[7] + self.INIT_DIM + self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], + self.LAYERS[7]) + + self.final = ME.MinkowskiConvolution( + self.PLANES[7] * self.BLOCK.expansion, + out_channels, + kernel_size=1, + bias=True, + dimension=D) + self.relu = ME.MinkowskiReLU(inplace=True) + + def forward(self, x, return_feats=False): + out = self.conv0p1s1(x) + out = self.bn0(out) + out_p1 = self.relu(out) + + out = self.conv1p1s2(out_p1) + out = self.bn1(out) + out = self.relu(out) + out_b1p2 = self.block1(out) + + out = self.conv2p2s2(out_b1p2) + out = self.bn2(out) + out = self.relu(out) + out_b2p4 = self.block2(out) + + out = self.conv3p4s2(out_b2p4) + out = self.bn3(out) + out = self.relu(out) + out_b3p8 = self.block3(out) + + # tensor_stride=16 + out = self.conv4p8s2(out_b3p8) + out = self.bn4(out) + out = self.relu(out) + out = self.block4(out) + + # tensor_stride=8 + out = self.convtr4p16s2(out) + out = self.bntr4(out) + out = self.relu(out) + + out = ME.cat(out, out_b3p8) + out = self.block5(out) + + # tensor_stride=4 + out = self.convtr5p8s2(out) + out = self.bntr5(out) + out = self.relu(out) + + out = ME.cat(out, out_b2p4) + out = self.block6(out) + + # tensor_stride=2 + out = self.convtr6p4s2(out) + out = self.bntr6(out) + out = self.relu(out) + + out = ME.cat(out, out_b1p2) + out = self.block7(out) + + # tensor_stride=1 + out = self.convtr7p2s2(out) + out = self.bntr7(out) + out = self.relu(out) + + out = ME.cat(out, out_p1) + out = self.block8(out) + + if not return_feats: + return self.final(out) + else: + return out + + +class MinkUNet34(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +class MinkUNet50(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +class MinkUNet101(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) + + +class MinkUNet34A(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 64, 64) + + +class MinkUNet34B(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 64, 32) + + +class MinkUNet34C(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) diff --git a/models/multiheadminkunet.py b/models/multiheadminkunet.py new file mode 100644 index 0000000..366ead3 --- /dev/null +++ b/models/multiheadminkunet.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import MinkowskiEngine as ME + +from models.minkunet import MinkUNet34C + +class Prototypes(nn.Module): + def __init__(self, output_dim, num_prototypes, D=3): + super().__init__() + + self.prototypes = ME.MinkowskiConvolution( + output_dim, + num_prototypes, + kernel_size=1, + bias=False, + dimension=D) + + def forward(self, x): + return self.prototypes(x).F + + +class MultiHead(nn.Module): + def __init__( + self, input_dim, num_prototypes, num_heads + ): + super().__init__() + self.num_heads = num_heads + + # prototypes + self.prototypes = torch.nn.ModuleList( + [Prototypes(input_dim, num_prototypes) for _ in range(num_heads)] + ) + + def forward_head(self, head_idx, feats): + return self.prototypes[head_idx](feats), feats.F + + def forward(self, feats): + out = [self.forward_head(h, feats) for h in range(self.num_heads)] + return [torch.stack(o) for o in map(list, zip(*out))] + + +class MultiHeadMinkUnet(nn.Module): + def __init__( + self, + num_labeled, + num_unlabeled, + overcluster_factor=None, + num_heads=1 + ): + super().__init__() + + # backbone -> pretrained model + identity as final + self.encoder = MinkUNet34C(1, num_labeled) + self.feat_dim = self.encoder.final.in_channels + self.encoder.final = nn.Identity() + + self.head_lab = Prototypes(output_dim=self.feat_dim, + num_prototypes=num_labeled) + if num_heads is not None: + self.head_unlab = MultiHead( + input_dim=self.feat_dim, + num_prototypes=num_unlabeled, + num_heads=num_heads + ) + + if overcluster_factor is not None: + self.head_unlab_over = MultiHead( + input_dim=self.feat_dim, + num_prototypes=num_unlabeled * overcluster_factor, + num_heads=num_heads + ) + + def forward_heads(self, feats): + out = {"logits_lab": self.head_lab(feats)} + if hasattr(self, "head_unlab"): + logits_unlab, proj_feats_unlab = self.head_unlab(feats) + out.update( + { + "logits_unlab": logits_unlab, + "proj_feats_unlab": proj_feats_unlab, + } + ) + if hasattr(self, "head_unlab_over"): + logits_unlab_over, proj_feats_unlab_over = self.head_unlab_over(feats) + out.update( + { + "logits_unlab_over": logits_unlab_over, + "proj_feats_unlab_over": proj_feats_unlab_over, + } + ) + return out + + def forward(self, views): + if isinstance(views, list): + feats = [self.encoder(view) for view in views] + out = [self.forward_heads(f) for f in feats] + out_dict = {"feats": torch.stack(feats)} + for key in out[0].keys(): + out_dict[key] = torch.stack([o[key] for o in out]) + return out_dict + else: + feats = self.encoder(views) + out = self.forward_heads(feats) + out["feats"] = feats.F + return out \ No newline at end of file diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000..234edcf --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,132 @@ +# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural +# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part +# of the code. +import torch.nn as nn + +import MinkowskiEngine as ME + + +class ResNetBase(nn.Module): + BLOCK = None + LAYERS = () + INIT_DIM = 64 + PLANES = (64, 128, 256, 512) + + def __init__(self, in_channels, out_channels, D=3): + nn.Module.__init__(self) + self.D = D + assert self.BLOCK is not None + + self.network_initialization(in_channels, out_channels, D) + self.weight_initialization() + + def network_initialization(self, in_channels, out_channels, D): + + self.inplanes = self.INIT_DIM + self.conv1 = nn.Sequential( + ME.MinkowskiConvolution( + in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D + ), + ME.MinkowskiInstanceNorm(self.inplanes), + ME.MinkowskiReLU(inplace=True), + ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D), + ) + + self.layer1 = self._make_layer( + self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2 + ) + self.layer2 = self._make_layer( + self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2 + ) + self.layer3 = self._make_layer( + self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2 + ) + self.layer4 = self._make_layer( + self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2 + ) + + self.conv5 = nn.Sequential( + ME.MinkowskiDropout(), + ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D + ), + ME.MinkowskiInstanceNorm(self.inplanes), + ME.MinkowskiGELU(), + ) + + self.glob_pool = ME.MinkowskiGlobalMaxPooling() + + self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, ME.MinkowskiConvolution): + ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") + + if isinstance(m, ME.MinkowskiBatchNorm): + nn.init.constant_(m.bn.weight, 1) + nn.init.constant_(m.bn.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + ME.MinkowskiConvolution( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + dimension=self.D, + ), + ME.MinkowskiBatchNorm(planes * block.expansion), + ) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride=stride, + dilation=dilation, + downsample=downsample, + dimension=self.D, + ) + ) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x: ME.SparseTensor): + x = self.conv1(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.conv5(x) + x = self.glob_pool(x) + return self.final(x) \ No newline at end of file diff --git a/modules/Discoverer.py b/modules/Discoverer.py new file mode 100644 index 0000000..f739512 --- /dev/null +++ b/modules/Discoverer.py @@ -0,0 +1,766 @@ +import os +import sys +from itertools import chain as chain_iterators + +import MinkowskiEngine as ME +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +import yaml +from scipy.optimize import linear_sum_assignment +from torch import optim +from torch.utils.data import DataLoader +from torchmetrics.functional import jaccard_index +from tqdm import tqdm + +from models.multiheadminkunet import MultiHeadMinkUnet +from utils.collation import ( + collation_fn_restricted_dataset, + collation_fn_restricted_dataset_two_samples, +) +from utils.dataset import dataset_wrapper, get_dataset +from utils.scheduler import LinearWarmupCosineAnnealingLR +from utils.sinkhorn_knopp import SinkhornKnopp + + +class Discoverer(pl.LightningModule): + def __init__(self, label_mapping, label_mapping_inv, unknown_label, **kwargs): + + super().__init__() + self.save_hyperparameters( + {k: v for (k, v) in kwargs.items() if not callable(v)} + ) + + self.model = MultiHeadMinkUnet( + num_labeled=self.hparams.num_labeled_classes, + num_unlabeled=self.hparams.num_unlabeled_classes, + overcluster_factor=self.hparams.overcluster_factor, + num_heads=self.hparams.num_heads + ) + + self.label_mapping = label_mapping + self.label_mapping_inv = label_mapping_inv + self.unknown_label = unknown_label + + if self.hparams.pretrained is not None: + state_dict = torch.load(self.hparams.pretrained) + missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False) + print(f'Missing: {missing_keys}', f'Unexpected: {unexpected_keys}') + + # Sinkorn-Knopp + self.sk = SinkhornKnopp( + num_iters=self.hparams.num_iters_sk, epsilon=self.hparams.initial_epsilon_sk + ) + + self.sk_queue = None + self.sk_indices = [] + + self.loss_per_head = torch.zeros(self.hparams.num_heads, device=self.device) + + # wCE as loss + self.criterion = torch.nn.CrossEntropyLoss(reduction="none") + weights = torch.ones(len(self.label_mapping)) / len(self.label_mapping) + self.criterion.weight = weights + + self.valid_criterion = torch.nn.CrossEntropyLoss() + weights = torch.ones(len(self.label_mapping)) / len(self.label_mapping) + self.valid_criterion.weight = weights + + # Mapping numeric_label -> word_label + dataset_config_file = self.hparams.dataset_config + with open(dataset_config_file, "r") as f: + dataset_config = yaml.safe_load(f) + map_inv = dataset_config["learning_map_inv"] + lab_dict = dataset_config["labels"] + label_dict = {} + for new_label, old_label in map_inv.items(): + label_dict[new_label] = lab_dict[old_label] + self.label_dict = label_dict + + return + + def configure_optimizers(self): + if self.hparams.pretrained is not None: + encoder_params = self.model.encoder.parameters() + rest_params = chain_iterators( + self.model.head_lab.parameters(), self.model.head_unlab.parameters() + ) + if hasattr(self.model, "head_unlab_over"): + rest_params = chain_iterators( + rest_params, self.model.head_unlab_over.parameters() + ) + optimizer = optim.SGD( + [ + {"params": rest_params, "lr": self.hparams.train_lr}, + {"params": encoder_params}, + ], + lr=self.hparams.finetune_lr, + momentum=self.hparams.momentum_for_optim, + weight_decay=self.hparams.weight_decay_for_optim, + ) + else: + optimizer = optim.SGD( + params=self.model.parameters(), + lr=self.hparams.train_lr, + momentum=self.hparams.momentum_for_optim, + weight_decay=self.hparams.weight_decay_for_optim, + ) + + if self.hparams.use_scheduler: + scheduler = LinearWarmupCosineAnnealingLR( + optimizer, + warmup_epochs=self.hparams.warmup_epochs, + max_epochs=self.hparams.epochs, + warmup_start_lr=self.hparams.min_lr, + eta_min=self.hparams.min_lr, + ) + + return [optimizer], [scheduler] + + return optimizer + + def on_train_start(self): + # Compute/load weights for weighted CE loss + if not os.path.exists("weights.pt"): + dataset = get_dataset(self.hparams.dataset)( + config_file=self.hparams.dataset_config, + split="train", + voxel_size=self.hparams.voxel_size, + downsampling=self.hparams.downsampling, + augment=True, + label_mapping=self.label_mapping, + ) + + weights = torch.zeros((self.hparams.num_classes), device=self.device) + + dataloader = DataLoader( + dataset=dataset, + batch_size=self.hparams.batch_size, + collate_fn=collation_fn_restricted_dataset, + num_workers=self.hparams.num_workers, + shuffle=False, + ) + + # Split each unknown point across the 5 (or 4) unknown classes + unk_labels_num = self.hparams.num_unlabeled_classes + with tqdm( + total=len(dataloader), + desc="Evaluating weights for wCE", + file=sys.stdout, + ) as pbar: + for _, _, _, _, labels, _ in dataloader: + for label in set(self.label_mapping.values()): + n_points = (labels == label).nonzero().numel() + if label != self.unknown_label: + weights[label] += n_points + else: + weights[-unk_labels_num:] += n_points / unk_labels_num + pbar.update() + + weights += 1 + weights = 1 / weights + weights = weights / torch.sum(weights) + self.criterion.weight = weights + torch.save(weights, "weights.pt") + else: + print("\nLoading weights.pt ...", flush=True) + weights = torch.load("weights.pt").to(self.device) + self.criterion.weight = weights + + def train_dataloader(self): + + dataset = get_dataset(self.hparams.dataset)( + config_file=self.hparams.dataset_config, + split="train", + voxel_size=self.hparams.voxel_size, + downsampling=self.hparams.downsampling, + augment=True, + label_mapping=self.label_mapping, + ) + + dataset = dataset_wrapper(dataset) + + dataloader = DataLoader( + dataset=dataset, + batch_size=self.hparams.batch_size, + collate_fn=collation_fn_restricted_dataset_two_samples, + num_workers=self.hparams.num_workers, + shuffle=True, + ) + + return dataloader + + def val_dataloader(self): + + dataset = get_dataset(self.hparams.dataset)( + config_file=self.hparams.dataset_config, + split="valid", + voxel_size=self.hparams.voxel_size, + label_mapping=self.label_mapping, + ) + + dataloader = DataLoader( + dataset=dataset, + batch_size=self.hparams.batch_size, + collate_fn=collation_fn_restricted_dataset, + num_workers=self.hparams.num_workers, + ) + + return dataloader + + def on_train_epoch_start(self): + # Reset best_head tracker + self.loss_per_head = torch.zeros_like(self.loss_per_head, device=self.device) + + # Compute the actual epsilon for Sinkhorn-Knopp + if self.hparams.adapting_epsilon_sk and self.hparams.epochs > 1: + eps_0 = self.hparams.initial_epsilon_sk + eps_n = self.hparams.final_epsilon_sk + n_ep = self.hparams.epochs + act_ep = self.current_epoch + self.sk.epsilon = eps_0 + act_ep * (eps_n - eps_0) / (n_ep - 1) + + def training_step(self, data, _): + def get_uncertainty_mask(preds: torch.Tensor, p=0.5): + """ + returns a boolean mask selecting the p-th percentile of the predictions with highest confidence for each class + + :param preds: Tensor of predicted logits (N x Nc) + :param p: float describing the percentile to use in the selection + """ + + self.log(f"utils/tot_p", preds.shape[0]) + + # init mask + uncertainty_mask = torch.zeros( + preds.shape[0], dtype=torch.bool, device=self.device + ) + + # get hard predictions + hard_preds = preds.argmax(dim=-1) + + # generate indexes for consistent mapping + indexes = torch.arange(preds.shape[0], device=self.device) + + # for each novel class + for un_tmp in range(self.hparams.num_unlabeled_classes): + # select points with given novel class + un_idx_tmp = hard_preds == un_tmp + + if (un_idx_tmp.sum() * p).int() > 0: + # select confident novel pts + un_conf = preds[un_idx_tmp].softmax(-1)[:, un_tmp] + un_sel_tmp = indexes[un_idx_tmp] + + # sort them + sorted_conf_tmp, sorted_idx_tmp = torch.sort(un_conf) + un_conf = un_conf[sorted_idx_tmp] + un_sel_tmp = un_sel_tmp[sorted_idx_tmp] + + # get percentile idx + perc_tmp = (un_idx_tmp.sum() * p).int() + + # update th + un_th_tmp = un_conf[perc_tmp] + + # find valid pts + mask_tmp = un_conf > un_th_tmp + + self.log(f"utils/thr_{un_tmp}", un_th_tmp) + self.log( + f"utils/perc_{un_tmp}", mask_tmp.sum() / un_sel_tmp.shape[0] + ) + self.log(f"utils/tot_p_{un_tmp}", un_sel_tmp.shape[0]) + + uncertainty_mask[un_sel_tmp[mask_tmp]] = 1 + + return uncertainty_mask + + nlc = self.hparams.num_labeled_classes + + ( + coords, + feats, + _, + selected_idx, + mapped_labels, + coords1, + feats1, + _, + selected_idx1, + mapped_labels1, + pcd_indexes, + ) = data + + pcd_masks = [] + pcd_masks1 = [] + for i in range(pcd_indexes.shape[0]): + pcd_masks.append(coords[:, 0] == i) + pcd_masks1.append(coords1[:, 0] == i) + + # Forward + coords = coords.int() + coords1 = coords1.int() + + sp_tensor = ME.SparseTensor(features=feats.float(), coordinates=coords) + sp_tensor1 = ME.SparseTensor(features=feats1.float(), coordinates=coords1) + + # Clear cache at regular interval + if self.global_step % self.hparams.clear_cache_int == 0: + torch.cuda.empty_cache() + + out = self.model(sp_tensor) + out1 = self.model(sp_tensor1) + + # Gather outputs + out["logits_lab"] = ( + out["logits_lab"].unsqueeze(0).expand(self.hparams.num_heads, -1, -1) + ) + out1["logits_lab"] = ( + out1["logits_lab"].unsqueeze(0).expand(self.hparams.num_heads, -1, -1) + ) + logits = torch.cat([out["logits_lab"], out["logits_unlab"]], dim=-1) + logits1 = torch.cat([out1["logits_lab"], out1["logits_unlab"]], dim=-1) + if self.hparams.overcluster_factor is not None: + logits_over = torch.cat( + [out["logits_lab"], out["logits_unlab_over"]], dim=-1 + ) + logits_over1 = torch.cat( + [out1["logits_lab"], out1["logits_unlab_over"]], dim=-1 + ) + + mask_lab = mapped_labels != self.unknown_label + mask_lab1 = mapped_labels1 != self.unknown_label + + # Generate one-hot targets for the base points + targets_lab = ( + F.one_hot( + mapped_labels[mask_lab].to(torch.long), + num_classes=self.hparams.num_labeled_classes, + ) + .float() + .to(self.device) + ) + targets_lab1 = ( + F.one_hot( + mapped_labels1[mask_lab1].to(torch.long), + num_classes=self.hparams.num_labeled_classes, + ) + .float() + .to(self.device) + ) + + # Generate empty targets for all the points + targets = torch.zeros_like(logits) + targets1 = torch.zeros_like(logits1) + if self.hparams.overcluster_factor is not None: + targets_over = torch.zeros_like(logits_over) + targets_over1 = torch.zeros_like(logits_over1) + + # Generate pseudo-labels with sinkhorn-knopp and fill unlab targets + act_queue = ( + None + if self.current_epoch < self.hparams.queue_start_epoch + else self.sk_queue + ) + for h in range(self.hparams.num_heads): + # Insert the one-hot labels + targets[h, mask_lab, :nlc] = targets_lab.type_as(targets) + targets1[h, mask_lab1, :nlc] = targets_lab1.type_as(targets1) + + if self.hparams.use_uncertainty_queue or self.hparams.use_uncertainty_loss: + # Get masks for certain points + unc_mask = get_uncertainty_mask( + out["logits_unlab"][h][~mask_lab].detach(), + p=self.hparams.uncertainty_percentile, + ) + unc_mask1 = get_uncertainty_mask( + out1["logits_unlab"][h][~mask_lab1].detach(), + p=self.hparams.uncertainty_percentile, + ) + if h == 0: + unc_mask_overall = unc_mask + unc_mask_overall1 = unc_mask1 + else: + unc_mask_overall = torch.logical_and(unc_mask_overall, unc_mask) + unc_mask_overall1 = torch.logical_and(unc_mask_overall1, unc_mask1) + + if self.hparams.use_uncertainty_loss: + # Get predictions from Sinkhorn only for high-confidence points + pred_sk = self.sk( + out["feats"][~mask_lab][unc_mask], + self.model.head_unlab.prototypes[h].prototypes.kernel.data, + queue=act_queue, + ).type_as(targets) + pred_sk1 = self.sk( + out1["feats"][~mask_lab1][unc_mask1], + self.model.head_unlab.prototypes[h].prototypes.kernel.data, + queue=act_queue, + ).type_as(targets) + + new_mask_unlab = ~mask_lab.clone() + new_mask_unlab[new_mask_unlab == True] = unc_mask + new_mask_unlab1 = ~mask_lab1.clone() + new_mask_unlab1[new_mask_unlab1 == True] = unc_mask1 + # Use sinkhorn labels only with the confident points (unconfident ones remain zero_labelled) + targets[h, new_mask_unlab, nlc:] = pred_sk + targets1[h, new_mask_unlab1, nlc:] = pred_sk1 + else: + # Insert sinkhorn labels + targets[h, ~mask_lab, nlc:] = self.sk( + out["feats"][~mask_lab], + self.model.head_unlab.prototypes[h].prototypes.kernel.data, + queue=act_queue, + ).type_as(targets) + targets1[h, ~mask_lab1, nlc:] = self.sk( + out1["feats"][~mask_lab1], + self.model.head_unlab.prototypes[h].prototypes.kernel.data, + queue=act_queue, + ).type_as(targets) + + if self.hparams.overcluster_factor is not None: + # Manage also overclustering heads + targets_over[h, mask_lab, :nlc] = targets_lab.type_as(targets) + targets_over[h, ~mask_lab, nlc:] = self.sk( + out["feats"][~mask_lab], + self.model.head_unlab_over.prototypes[h].prototypes.kernel.data, + queue=act_queue, + ).type_as(targets) + targets_over1[h, mask_lab1, :nlc] = targets_lab1.type_as(targets1) + targets_over1[h, ~mask_lab1, nlc:] = self.sk( + out1["feats"][~mask_lab1], + self.model.head_unlab_over.prototypes[h].prototypes.kernel.data, + queue=act_queue, + ).type_as(targets1) + + # Evaluate loss + loss_cluster = self.loss( + logits, targets1, selected_idx, selected_idx1, pcd_masks, pcd_masks1 + ) + loss_cluster += self.loss( + logits1, targets, selected_idx1, selected_idx, pcd_masks1, pcd_masks + ) + + if self.hparams.overcluster_factor is not None: + loss_overcluster = self.loss( + logits_over, + targets_over1, + selected_idx, + selected_idx1, + pcd_masks, + pcd_masks1, + ) + loss_overcluster += self.loss( + logits_over1, + targets_over, + selected_idx1, + selected_idx, + pcd_masks1, + pcd_masks, + ) + else: + loss_overcluster = loss_cluster + + # Keep track of the loss for each head + self.loss_per_head += loss_cluster.clone().detach() + + loss_cluster = loss_cluster.mean() + loss_overcluster = loss_overcluster.mean() + loss = (loss_cluster + loss_overcluster) / 2 + + # logging + results = { + "train/loss": loss.detach(), + "train/loss_cluster": loss_cluster.detach(), + } + + if self.hparams.overcluster_factor is not None: + results["train/loss_overcluster"] = loss_overcluster.detach() + + self.log_dict(results, on_step=True, on_epoch=True, sync_dist=True) + + if self.hparams.queue_start_epoch != -1: + if self.hparams.use_uncertainty_queue: + self.update_queue( + torch.cat( + ( + out["feats"][~mask_lab][unc_mask_overall], + out1["feats"][~mask_lab1][unc_mask_overall1], + ) + ) + ) + else: + self.update_queue( + torch.cat((out["feats"][~mask_lab], out1["feats"][~mask_lab1])) + ) + + return loss + + def update_queue(self, feats: torch.Tensor): + """ + Updates self.queue with the features of the novel points in the current batch + + :param feats: the features for the novel points in the current batch + """ + feats = feats.detach() + if not self.hparams.use_uncertainty_queue: + n_feats_to_retain = int(feats.shape[0] * self.hparams.queue_percentage) + mask = torch.randperm(feats.shape[0])[:n_feats_to_retain] + else: + n_feats_to_retain = feats.shape[0] + mask = torch.ones(n_feats_to_retain, device=feats.device, dtype=torch.bool) + if self.sk_queue is None: + self.sk_queue = feats[mask] + self.sk_indices.append(n_feats_to_retain) + return + + if len(self.sk_indices) < self.hparams.queue_batches: + self.sk_queue = torch.vstack((feats[mask], self.sk_queue)) + self.sk_indices.insert(0, n_feats_to_retain) + else: + self.sk_queue = torch.vstack( + (feats[mask], self.sk_queue[: -self.sk_indices[-1]]) + ) + self.sk_indices.insert(0, n_feats_to_retain) + del self.sk_indices[-1] + + def loss( + self, + logits: torch.Tensor, + targets: torch.Tensor, + idx_logits: torch.Tensor, + idx_targets: torch.Tensor, + pcd_mask_logits: torch.Tensor, + pcd_mask_targets: torch.Tensor, + ): + """ + Evaluates the loss function of the predicted logits w.r.t. the targets + + :param logits: predicted logits for the first augmentation of the point clouds + :param targets: targets for the second augmentation of the point clouds + :param idx_logits: indexes of the selected points in the first augmentation of the point clouds + :param idx_targets: indexes of the selected points in the second augmentation of the point clouds + :param pcd_mask_logits: mask to separate the different point clouds in the batch + :param pcd_mask_targets: mask to separate the different point clouds in the batch + """ + + if self.criterion.weight.shape[0] != targets.shape[2]: + weight_bck = self.criterion.weight.clone() + weight_new = torch.zeros(targets.shape[2], device=self.device) + weight_new[: self.hparams.num_labeled_classes] = weight_bck[ + : self.hparams.num_labeled_classes + ] + new_weight_tmp = weight_bck[-1] / self.hparams.overcluster_factor + weight_new[ + -self.hparams.num_unlabeled_classes * self.hparams.overcluster_factor : + ] = new_weight_tmp + self.criterion.weight = weight_new + else: + weight_bck = None + + heads_loss = None + + for head in range(self.hparams.num_heads): + head_loss = None + for pcd in range(len(pcd_mask_logits)): + pcd_logits = logits[head][pcd_mask_logits[pcd]] + pcd_targets = targets[head][pcd_mask_targets[pcd]] + #### + logit_shape = pcd_logits.shape[0] + target_shape = pcd_targets.shape[0] + #### + mask_logits = torch.isin( + idx_logits[pcd_mask_logits[pcd]], idx_targets[pcd_mask_targets[pcd]] + ) + mask_targets = torch.isin( + idx_targets[pcd_mask_targets[pcd]], idx_logits[pcd_mask_logits[pcd]] + ) + pcd_logits = pcd_logits[mask_logits] + pcd_targets = pcd_targets[mask_targets] + #### + perc_to_log = ( + pcd_logits.shape[0] / logit_shape + + pcd_targets.shape[0] / target_shape + ) / 2 + # print(perc_to_log) + self.log("utils/points_in_common", perc_to_log) + #### + + loss = self.criterion(pcd_logits, pcd_targets) + if self.hparams.use_uncertainty_loss: + loss = loss[loss > 0] + # pre-compute data for wCE + multiplier = 1 / ((self.criterion.weight * pcd_targets).sum(1)).sum(0) + loss *= multiplier + loss = loss.sum() + if head_loss is None: + head_loss = loss + else: + head_loss = torch.hstack((head_loss, loss)) + + if heads_loss is None: + heads_loss = head_loss.mean() + else: + heads_loss = torch.hstack((heads_loss, head_loss.mean())) + + if weight_bck is not None: + self.criterion.weight = weight_bck + + return heads_loss + + def on_validation_epoch_start(self): + # Run the hungarian algorithm to map each novel class to the related semantic class + if ( + self.hparams.hungarian_at_each_step + or len(self.label_mapping_inv) < self.hparams.num_classes + ): + cost_matrix = torch.zeros( + ( + self.hparams.num_unlabeled_classes, + self.hparams.num_unlabeled_classes, + ), + device=self.device, + ) + + dataset = get_dataset(self.hparams.dataset)( + config_file=self.hparams.dataset_config, + split="valid", + voxel_size=self.hparams.voxel_size, + label_mapping=self.label_mapping, + ) + + dataloader = DataLoader( + dataset=dataset, + batch_size=self.hparams.batch_size, + collate_fn=collation_fn_restricted_dataset, + num_workers=self.hparams.num_workers, + ) + + real_labels_to_be_matched = [ + label + for label in self.label_mapping + if self.label_mapping[label] == self.unknown_label + ] + + with tqdm( + total=len(dataloader), desc="Cost matrix build-up", file=sys.stdout + ) as pbar: + for step, data in enumerate(dataloader): + coords, feats, real_labels, _, mapped_labels, _ = data + + # Forward + coords = coords.int().to(self.device) + feats = feats.to(self.device) + real_labels = real_labels.to(self.device) + + sp_tensor = ME.SparseTensor( + features=feats.float(), coordinates=coords + ) + + # Must clear cache at regular interval + if self.global_step % self.hparams.clear_cache_int == 0: + torch.cuda.empty_cache() + + out = self.model(sp_tensor) + + best_head = torch.argmin(self.loss_per_head) + + mask_unknown = mapped_labels == self.unknown_label + + preds = out["logits_unlab"][best_head] + preds = torch.argmax(preds[mask_unknown].softmax(1), dim=1) + + for pseudo_label in range(self.hparams.num_unlabeled_classes): + mask_pseudo = preds == pseudo_label + for j, real_label in enumerate(real_labels_to_be_matched): + mask_real = real_labels[mask_unknown] == real_label + cost_matrix[pseudo_label, j] += torch.logical_and( + mask_pseudo, mask_real + ).sum() + + pbar.update() + + cost_matrix = cost_matrix / ( + torch.negative(cost_matrix) + + torch.sum(cost_matrix, dim=0) + + torch.sum(cost_matrix, dim=1).unsqueeze(1) + + 1e-5 + ) + + # Hungarian + cost_matrix = cost_matrix.cpu() + row_ind, col_ind = linear_sum_assignment( + cost_matrix=cost_matrix, maximize=True + ) + label_mapping = { + row_ind[i] + self.unknown_label: real_labels_to_be_matched[col_ind[i]] + for i in range(len(row_ind)) + } + self.label_mapping_inv.update(label_mapping) + + # Reorder weights for validation loss + weights = self.criterion.weight.clone() + sorted_label_mapping_inv = dict( + sorted(self.label_mapping_inv.items(), key=lambda item: item[1]) + ) + sorter = list(sorted_label_mapping_inv.keys()) + self.valid_criterion.weight = weights[sorter] + + return + + def validation_step(self, data, _): + coords, feats, real_labels, _, _, _ = data + + # Forward + coords = coords.int() + + sp_tensor = ME.SparseTensor(features=feats.float(), coordinates=coords) + + # Must clear cache at regular interval + if self.global_step % self.hparams.clear_cache_int == 0: + torch.cuda.empty_cache() + + out = self.model(sp_tensor) + + best_head = torch.argmin(self.loss_per_head) + + preds = torch.cat([out["logits_lab"], out["logits_unlab"][best_head]], dim=-1) + + sorted_label_mapping_inv = dict( + sorted(self.label_mapping_inv.items(), key=lambda item: item[1]) + ) + sorter = list(sorted_label_mapping_inv.keys()) + + preds = preds[:, sorter] + + loss = self.valid_criterion(preds, real_labels.long()) + + gt_labels = real_labels + avail_labels = torch.unique(gt_labels).long() + _, pred_labels = torch.max(torch.softmax(preds.detach(), dim=1), dim=1) + IoU = jaccard_index(gt_labels, pred_labels, reduction="none") + IoU = IoU[avail_labels] + + self.log("valid/loss", loss, on_epoch=True, sync_dist=True, rank_zero_only=True) + IoU_to_log = { + f"valid/IoU/{self.label_dict[int(avail_labels[i])]}": label_IoU + for i, label_IoU in enumerate(IoU) + } + for label, value in IoU_to_log.items(): + self.log(label, value, on_epoch=True, sync_dist=True, rank_zero_only=True) + + return loss + + def on_save_checkpoint(self, checkpoint): + # Maintain info about best head when saving checkpoints + checkpoint["loss_per_head"] = self.loss_per_head + return super().on_save_checkpoint(checkpoint) + + def on_load_checkpoint(self, checkpoint): + self.loss_per_head = checkpoint.get( + "loss_per_head", + torch.zeros( + checkpoint["hyper_parameters"]["num_heads"], device=self.device + ), + ) + return super().on_load_checkpoint(checkpoint) diff --git a/modules/Pretrainer.py b/modules/Pretrainer.py new file mode 100644 index 0000000..6b02e65 --- /dev/null +++ b/modules/Pretrainer.py @@ -0,0 +1,239 @@ +import os +import sys + +import MinkowskiEngine as ME +import pytorch_lightning as pl +import torch +import yaml +from torch import optim +from torch.utils.data import DataLoader +from torchmetrics.functional import jaccard_index +from tqdm import tqdm + +from models.multiheadminkunet import MultiHeadMinkUnet +from utils.collation import collation_fn_restricted_dataset +from utils.dataset import get_dataset + + +class Pretrainer(pl.LightningModule): + def __init__(self, label_mapping, label_mapping_inv, unknown_label, **kwargs): + + super().__init__() + self.save_hyperparameters( + {k: v for (k, v) in kwargs.items() if not callable(v)} + ) + + self.model = MultiHeadMinkUnet( + num_labeled=self.hparams.num_labeled_classes, + num_unlabeled=self.hparams.num_unlabeled_classes, + num_heads=None, + ) + + self.label_mapping = label_mapping + self.label_mapping_inv = label_mapping_inv + self.unknown_label = unknown_label + + # wCE as loss + self.criterion = torch.nn.CrossEntropyLoss() + weights = ( + torch.ones(self.hparams.num_labeled_classes) + / self.hparams.num_labeled_classes + ) + self.criterion.weight = weights + + # Mapping numeric_label -> word_label + dataset_config_file = self.hparams.dataset_config + with open(dataset_config_file, "r") as f: + dataset_config = yaml.safe_load(f) + map_inv = dataset_config["learning_map_inv"] + lab_dict = dataset_config["labels"] + label_dict = {} + for new_label, old_label in map_inv.items(): + label_dict[new_label] = lab_dict[old_label] + self.label_dict = label_dict + + return + + def configure_optimizers(self): + optimizer = optim.SGD( + params=self.model.parameters(), + lr=self.hparams.train_lr, + momentum=self.hparams.momentum_for_optim, + weight_decay=self.hparams.weight_decay_for_optim, + ) + + return optimizer + + def on_train_start(self): + # Compute/load weights for weighted CE loss + if not os.path.exists("pret_weights.pt"): + dataset = get_dataset(self.hparams.dataset)( + config_file=self.hparams.dataset_config, + split="train", + voxel_size=self.hparams.voxel_size, + label_mapping=self.label_mapping, + ) + + weights = torch.zeros( + (self.hparams.num_labeled_classes), device=self.device + ) + + dataloader = DataLoader( + dataset=dataset, + batch_size=self.hparams.batch_size, + collate_fn=collation_fn_restricted_dataset, + num_workers=self.hparams.num_workers, + shuffle=False, + ) + + with tqdm( + total=len(dataloader), + desc="Evaluating weights for wCE", + file=sys.stdout, + ) as pbar: + for _, _, _, _, labels, _ in dataloader: + for label in set(self.label_mapping.values()): + n_points = (labels == label).nonzero().numel() + if label != self.unknown_label: + weights[label] += n_points + pbar.update() + + weights += 1 + weights = 1 / weights + weights = weights / torch.sum(weights) + self.criterion.weight = weights + torch.save(weights, "pret_weights.pt") + else: + print("Loading pret_weights.pt ...", flush=True) + weights = torch.load("pret_weights.pt").to(self.device) + self.criterion.weight = weights + + def train_dataloader(self): + + dataset = get_dataset(self.hparams.dataset)( + config_file=self.hparams.dataset_config, + split="train", + voxel_size=self.hparams.voxel_size, + downsampling=self.hparams.downsampling, + augment=True, + label_mapping=self.label_mapping, + ) + + dataloader = DataLoader( + dataset=dataset, + batch_size=self.hparams.batch_size, + collate_fn=collation_fn_restricted_dataset, + num_workers=self.hparams.num_workers, + shuffle=True, + ) + + return dataloader + + def val_dataloader(self): + + dataset = get_dataset(self.hparams.dataset)( + config_file=self.hparams.dataset_config, + split="valid", + voxel_size=self.hparams.voxel_size, + label_mapping=self.label_mapping, + ) + + dataloader = DataLoader( + dataset=dataset, + batch_size=self.hparams.batch_size, + collate_fn=collation_fn_restricted_dataset, + num_workers=self.hparams.num_workers, + ) + + return dataloader + + def training_step(self, data, _): + coords, feats, real_labels, _, mapped_labels, _ = data + + # Forward + coords = coords.int() + + sp_tensor = ME.SparseTensor(features=feats.float(), coordinates=coords) + + # Must clear cache at regular interval + if self.global_step % self.hparams.clear_cache_int == 0: + torch.cuda.empty_cache() + + out = self.model(sp_tensor) + + mask_lab = mapped_labels != self.unknown_label + + preds = out["logits_lab"] + preds = preds[mask_lab] + + loss = self.criterion(preds, mapped_labels[mask_lab].long()) + + gt_labels = real_labels[mask_lab] + avail_labels = torch.unique(gt_labels).long() + pred_labels = torch.argmax(torch.softmax(preds.detach(), dim=1), dim=1) + # Transform predictions + for key, value in self.label_mapping_inv.items(): + pred_labels[pred_labels == key] = -value + pred_labels = -pred_labels + + IoU = jaccard_index(gt_labels, pred_labels, reduction="none") + IoU = IoU[avail_labels] + + # logging + results = { + "train/loss": loss.detach(), + } + + self.log_dict(results, on_step=True, on_epoch=True, sync_dist=True) + IoU_to_log = { + f"train/IoU/{self.label_dict[int(avail_labels[i])]}": label_IoU + for i, label_IoU in enumerate(IoU) + } + self.log_dict(IoU_to_log, on_step=False, on_epoch=True, sync_dist=True) + + return loss + + def validation_step(self, data, _): + coords, feats, real_labels, _, mapped_labels, _ = data + + # Forward + coords = coords.int() + + sp_tensor = ME.SparseTensor(features=feats.float(), coordinates=coords) + + # Must clear cache at regular interval + if self.global_step % self.hparams.clear_cache_int == 0: + torch.cuda.empty_cache() + + out = self.model(sp_tensor) + + mask_lab = mapped_labels != self.unknown_label + + preds = out["logits_lab"] + preds = preds[mask_lab] + + loss = self.criterion(preds, mapped_labels[mask_lab].long()) + + gt_labels = real_labels[mask_lab] + avail_labels = torch.unique(gt_labels).long() + pred_labels = torch.argmax(torch.softmax(preds.detach(), dim=1), dim=1) + # Transform predictions + for key, value in self.label_mapping_inv.items(): + pred_labels[pred_labels == key] = -value + pred_labels = -pred_labels + + IoU = jaccard_index(gt_labels, pred_labels, reduction="none") + IoU = IoU[avail_labels] + + # logging + results = { + "valid/loss": loss.detach(), + } + self.log_dict(results, on_step=True, on_epoch=True, sync_dist=True) + IoU_to_log = { + f"valid/IoU/{self.label_dict[int(avail_labels[i])]}": label_IoU + for i, label_IoU in enumerate(IoU) + } + self.log_dict(IoU_to_log, on_step=False, on_epoch=True, sync_dist=True) + + return diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/callbacks.py b/utils/callbacks.py new file mode 100644 index 0000000..b2e23c3 --- /dev/null +++ b/utils/callbacks.py @@ -0,0 +1,58 @@ +import os +import torch +import re +from pytorch_lightning import Callback + +class mIoUEvaluatorCallback(Callback): + def on_train_epoch_end(self, trainer, pl_module): + interesting_metric_regex = re.compile(r'train/IoU/[\S]+_epoch') + IoU_list = [] + callback_metrics = trainer.callback_metrics + for key in callback_metrics.keys(): + mo = interesting_metric_regex.search(key) + if mo is not None: + IoU_list.append(callback_metrics[key]) + if IoU_list: + mIoU = torch.mean(torch.stack(IoU_list)) + pl_module.log('train/mIoU', mIoU, rank_zero_only=True) + + def on_validation_epoch_end(self, trainer, pl_module): + interesting_metric_regex = re.compile(r'valid/IoU/[\S]+') + IoU_list = [] + callback_metrics = trainer.callback_metrics + for key in callback_metrics.keys(): + mo = interesting_metric_regex.search(key) + if mo is not None: + IoU_list.append(callback_metrics[key]) + if IoU_list: + mIoU = torch.mean(torch.stack(IoU_list)) + pl_module.log('valid/mIoU', mIoU, rank_zero_only=True) + + def on_test_epoch_end(self, trainer, pl_module): + interesting_metric_regex = re.compile(r'test/IoU/[\S]+') + IoU_list = [] + callback_metrics = trainer.callback_metrics + for key in callback_metrics.keys(): + mo = interesting_metric_regex.search(key) + if mo is not None: + IoU_list.append(callback_metrics[key]) + if IoU_list: + mIoU = torch.mean(torch.stack(IoU_list)) + pl_module.log('test/mIoU', mIoU, rank_zero_only=True) + +class PretrainCheckpointCallback(Callback): + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + checkpoint_filename = ( + "-".join( + [ + "pretrain", + str(pl_module.hparams.split), + pl_module.hparams.dataset, + pl_module.hparams.comment, + ] + ) + + ".ckpt" + ) + checkpoint_path = os.path.join( + pl_module.hparams.checkpoint_dir, checkpoint_filename) + torch.save(pl_module.model.state_dict(), checkpoint_path) \ No newline at end of file diff --git a/utils/collation.py b/utils/collation.py new file mode 100644 index 0000000..3771948 --- /dev/null +++ b/utils/collation.py @@ -0,0 +1,52 @@ +import MinkowskiEngine as ME +import torch +import numpy as np + +def collation_fn_dataset(data_labels): + coords, feats, labels, selected_idx, pcd_indexes = list(zip(*data_labels)) + + # Create batched coordinates for the SparseTensor input + bcoords = ME.utils.batched_coordinates(coords) + + # Concatenate all lists + feats = torch.from_numpy(np.concatenate(feats, 0)).float() + labels = torch.from_numpy(np.concatenate(labels, 0)).int() + selected_idx = torch.from_numpy(np.concatenate(selected_idx, 0)).long() + pcd_indexes = torch.tensor(pcd_indexes, dtype=torch.int16) + + return bcoords, feats, labels, selected_idx, pcd_indexes + +def collation_fn_restricted_dataset(data_labels): + coords, feats, labels, selected_idx, mapped_labels, pcd_indexes = list(zip(*data_labels)) + + # Create batched coordinates for the SparseTensor input + bcoords = ME.utils.batched_coordinates(coords) + + # Concatenate all lists + feats = torch.from_numpy(np.concatenate(feats, 0)).float() + labels = torch.from_numpy(np.concatenate(labels, 0)).int() + selected_idx = torch.from_numpy(np.concatenate(selected_idx, 0)).long() + mapped_labels = torch.from_numpy(np.concatenate(mapped_labels, 0)).int() + pcd_indexes = torch.tensor(pcd_indexes, dtype=torch.int16) + + return bcoords, feats, labels, selected_idx, mapped_labels, pcd_indexes + +def collation_fn_restricted_dataset_two_samples(data_labels): + coords, feats, labels, selected_idx, mapped_labels, coords1, feats1, labels1, selected_idx1, mapped_labels1, pcd_indexes = list(zip(*data_labels)) + + # Create batched coordinates for the SparseTensor input + bcoords = ME.utils.batched_coordinates(coords) + bcoords1 = ME.utils.batched_coordinates(coords1) + + # Concatenate all lists + feats = torch.from_numpy(np.concatenate(feats, 0)).float() + labels = torch.from_numpy(np.concatenate(labels, 0)).int() + selected_idx = torch.from_numpy(np.concatenate(selected_idx, 0)).long() + mapped_labels = torch.from_numpy(np.concatenate(mapped_labels, 0)).int() + feats1 = torch.from_numpy(np.concatenate(feats1, 0)).float() + labels1 = torch.from_numpy(np.concatenate(labels1, 0)).int() + selected_idx1 = torch.from_numpy(np.concatenate(selected_idx1, 0)).long() + mapped_labels1 = torch.from_numpy(np.concatenate(mapped_labels1, 0)).int() + pcd_indexes = torch.tensor(pcd_indexes, dtype=torch.int16) + + return bcoords, feats, labels, selected_idx, mapped_labels, bcoords1, feats1, labels1, selected_idx1, mapped_labels1, pcd_indexes \ No newline at end of file diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..4b515be --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,651 @@ +from copy import deepcopy +import os +import random + +import MinkowskiEngine as ME +import numpy as np +import torch +import yaml + +from utils.voxelizer import Voxelizer + + +def get_dataset(name): + if name == "SemanticKITTI": + return SemanticKITTIRestrictedDataset + elif name == "SemanticPOSS": + return SemanticPOSSRestrictedDataset + else: + raise NameError(f'Dataset "{name}" not yet implemented') + + +class SemanticKITTIDataset(torch.utils.data.Dataset): + def __init__( + self, + config_file="config/dataset.yaml", + split="train", + voxel_size=0.05, + downsampling=80000, + augment=False, + ): + """Load data from given dataset directory.""" + + with open(config_file, "r") as f: + self.config = yaml.safe_load(f) + + self.files = {"input": []} + if split != "test": + self.files["label"] = [] + self.filenames = [] + + self.voxel_size = voxel_size + self.downsampling = downsampling + self.augment = False + if split == "train" and augment: + self.augment = True + self.clip_bounds = None + self.scale_augmentation_bound = (0.95, 1.05) + self.rotation_augmentation_bound = ( + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + ) + self.translation_augmentation_ratio_bound = None + + self.voxelizer = Voxelizer( + voxel_size=self.voxel_size, + clip_bound=self.clip_bounds, + use_augmentation=self.augment, + scale_augmentation_bound=self.scale_augmentation_bound, + rotation_augmentation_bound=self.rotation_augmentation_bound, + translation_augmentation_ratio_bound=self.translation_augmentation_ratio_bound, + ignore_label=-1, + ) + + for sequence in self.config["split_sequence"][split]: + for idx, type in enumerate(self.files.keys()): + files_path = os.path.join( + self.config["dataset_path"], + "sequences", + sequence, + self.config["folder_name"][type], + ) + if not os.path.exists(files_path): + raise RuntimeError("Point cloud directory missing: " + files_path) + files = os.listdir(files_path) + data = sorted([os.path.join(files_path, f) for f in files]) + if len(data) == 0: + raise RuntimeError("Missing data for " + type) + self.files[type].extend(data) + if idx == 0: + self.filenames.extend(data) + + self.num_files = len(self.filenames) + + def __len__(self): + return self.num_files + + def __getitem__(self, t): + pc_filename = self.files["input"][t] + scan = np.fromfile(pc_filename, dtype=np.float32) + scan = scan.reshape((-1, 4)) + coordinates = scan[:, 0:3] # get xyz + remissions = scan[:, 3] # get remission + + features = np.ones((coordinates.shape[0], 1)) + + # AUGMENTATION + if self.augment: + # DOWNSAMPLING + selected_idx = np.random.choice( + coordinates.shape[0], self.downsampling, replace=False + ) + coordinates = coordinates[selected_idx] + features = features[selected_idx] + + # TRANSFORMATIONS + voxel_mtx, affine_mtx = self.voxelizer.get_transformation_matrix() + + rigid_transformation = affine_mtx @ voxel_mtx + # Apply transformations + + homo_coords = np.hstack( + ( + coordinates, + np.ones((coordinates.shape[0], 1), dtype=coordinates.dtype), + ) + ) + # coords = np.floor(homo_coords @ rigid_transformation.T[:, :3]) + coordinates = homo_coords @ rigid_transformation.T[:, :3] + else: + selected_idx = np.arange(coordinates.shape[0]) + + if "label" in self.files.keys(): + label_filename = self.files["label"][t] + labels = np.fromfile(label_filename, dtype=np.int32) + labels = labels.reshape((-1)) + labels = labels & 0xFFFF + if self.augment: + labels = labels[selected_idx] + for index, element in enumerate(labels): + labels[index] = self.config["learning_map"].get(element, -1) + else: + labels = np.negative(np.ones(coordinates.shape[0])) + + # REMOVE UNLABELED POINTS IF NOT IN TESTING + if "label" in self.files.keys(): + labelled_idx = labels != -1 + coordinates = coordinates[labelled_idx] + features = features[labelled_idx] + labels = labels[labelled_idx] + selected_idx = selected_idx[labelled_idx] + + discrete_coords, unique_map = ME.utils.sparse_quantize( + coordinates=coordinates, + return_index=True, + quantization_size=self.voxel_size, + ) + + unique_feats = features[unique_map] + unique_labels = labels[unique_map] + selected_idx = selected_idx[unique_map] + + return discrete_coords, unique_feats, unique_labels, selected_idx, t + + +class SemanticKITTIRestrictedDataset(torch.utils.data.Dataset): + def __init__( + self, + config_file="config/dataset.yaml", + split="train", + voxel_size=0.05, + downsampling=80000, + augment=False, + label_mapping=None, + ): + """Load data from given dataset directory.""" + + with open(config_file, "r") as f: + self.config = yaml.safe_load(f) + + self.files = {"input": []} + if split != "test": + self.files["label"] = [] + self.filenames = [] + + self.voxel_size = voxel_size + self.downsampling = downsampling + self.augment = False + if split == "train" and augment: + self.augment = True + self.clip_bounds = None + self.scale_augmentation_bound = (0.95, 1.05) + self.rotation_augmentation_bound = ( + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + ) + self.translation_augmentation_ratio_bound = None + + self.voxelizer = Voxelizer( + voxel_size=self.voxel_size, + clip_bound=self.clip_bounds, + use_augmentation=self.augment, + scale_augmentation_bound=self.scale_augmentation_bound, + rotation_augmentation_bound=self.rotation_augmentation_bound, + translation_augmentation_ratio_bound=self.translation_augmentation_ratio_bound, + ignore_label=-1, + ) + + for sequence in self.config["split_sequence"][split]: + for idx, type in enumerate(self.files.keys()): + files_path = os.path.join( + self.config["dataset_path"], + "sequences", + sequence, + self.config["folder_name"][type], + ) + if not os.path.exists(files_path): + raise RuntimeError("Point cloud directory missing: " + files_path) + files = os.listdir(files_path) + data = sorted([os.path.join(files_path, f) for f in files]) + if len(data) == 0: + raise RuntimeError("Missing data for " + type) + self.files[type].extend(data) + if idx == 0: + self.filenames.extend(data) + + self.num_files = len(self.filenames) + + if label_mapping is not None: + self.label_mapping_function = np.vectorize(lambda x: label_mapping[x]) + else: + self.label_mapping_function = None + + def __len__(self): + return self.num_files + + def __getitem__(self, t): + pc_filename = self.files["input"][t] + scan = np.fromfile(pc_filename, dtype=np.float32) + scan = scan.reshape((-1, 4)) + coordinates = scan[:, 0:3] # get xyz + remissions = scan[:, 3] # get remission + + features = np.ones((coordinates.shape[0], 1)) + + # AUGMENTATION + if self.augment: + # DOWNSAMPLING + selected_idx = np.random.choice( + coordinates.shape[0], self.downsampling, replace=False + ) + selected_idx = np.sort(selected_idx) + coordinates = coordinates[selected_idx] + features = features[selected_idx] + + # TRANSFORMATIONS + voxel_mtx, affine_mtx = self.voxelizer.get_transformation_matrix() + + rigid_transformation = affine_mtx @ voxel_mtx + # Apply transformations + + homo_coords = np.hstack( + ( + coordinates, + np.ones((coordinates.shape[0], 1), dtype=coordinates.dtype), + ) + ) + # coords = np.floor(homo_coords @ rigid_transformation.T[:, :3]) + coordinates = homo_coords @ rigid_transformation.T[:, :3] + else: + selected_idx = np.arange(coordinates.shape[0]) + + if "label" in self.files.keys(): + label_filename = self.files["label"][t] + labels = np.fromfile(label_filename, dtype=np.int32) + labels = labels.reshape((-1)) + labels = labels & 0xFFFF + if self.augment: + labels = labels[selected_idx] + for index, element in enumerate(labels): + labels[index] = self.config["learning_map"].get(element, -1) + else: + labels = np.negative(np.ones(coordinates.shape[0])) + + # REMOVE UNLABELED POINTS IF NOT IN TESTING + if "label" in self.files.keys(): + labelled_idx = labels != -1 + coordinates = coordinates[labelled_idx] + features = features[labelled_idx] + labels = labels[labelled_idx] + selected_idx = selected_idx[labelled_idx] + if self.label_mapping_function is not None: + mapped_labels = self.label_mapping_function(labels) + else: + mapped_labels = np.copy(labels) + + discrete_coords, unique_map = ME.utils.sparse_quantize( + coordinates=coordinates, + return_index=True, + quantization_size=self.voxel_size, + ) + + unique_feats = features[unique_map] + unique_labels = labels[unique_map] + unique_mapped_labels = mapped_labels[unique_map] + selected_idx = selected_idx[unique_map] + + return ( + discrete_coords, + unique_feats, + unique_labels, + selected_idx, + unique_mapped_labels, + t, + ) + + +class SemanticKITTIRestrictedDatasetCleanSplit(SemanticKITTIRestrictedDataset): + def __init__( + self, + clean_mask, + config_file="config/dataset.yaml", + split="train", + voxel_size=0.05, + downsampling=80000, + augment=False, + label_mapping=None, + ): + super().__init__( + config_file, split, voxel_size, downsampling, augment, label_mapping + ) + self.filenames = np.array(self.filenames)[clean_mask] + self.num_files = len(self.filenames) + for key in self.files.keys(): + self.files[key] = np.array(self.files[key])[clean_mask] + + +class SemanticPOSSDataset(torch.utils.data.Dataset): + def __init__( + self, + config_file="config/semposs_dataset.yaml", + split="train", + voxel_size=0.05, + downsampling=80000, + augment=False, + ): + """Load data from given dataset directory.""" + + with open(config_file, "r") as f: + self.config = yaml.safe_load(f) + + self.files = {"input": []} + if split != "test": + self.files["label"] = [] + self.filenames = [] + + self.voxel_size = voxel_size + self.downsampling = downsampling + self.augment = False + if split == "train" and augment: + self.augment = True + self.clip_bounds = None + self.scale_augmentation_bound = (0.95, 1.05) + self.rotation_augmentation_bound = ( + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + ) + self.translation_augmentation_ratio_bound = None + + self.voxelizer = Voxelizer( + voxel_size=self.voxel_size, + clip_bound=self.clip_bounds, + use_augmentation=self.augment, + scale_augmentation_bound=self.scale_augmentation_bound, + rotation_augmentation_bound=self.rotation_augmentation_bound, + translation_augmentation_ratio_bound=self.translation_augmentation_ratio_bound, + ignore_label=-1, + ) + + for sequence in self.config["split_sequence"][split]: + for idx, type in enumerate(self.files.keys()): + files_path = os.path.join( + self.config["dataset_path"], + "sequences", + sequence, + self.config["folder_name"][type], + ) + if not os.path.exists(files_path): + raise RuntimeError("Point cloud directory missing: " + files_path) + files = os.listdir(files_path) + data = sorted([os.path.join(files_path, f) for f in files]) + if len(data) == 0: + raise RuntimeError("Missing data for " + type) + self.files[type].extend(data) + if idx == 0: + self.filenames.extend(data) + + learning_map = self.config["learning_map"] + self.learning_map_function = np.vectorize(lambda x: learning_map[x]) + + self.num_files = len(self.filenames) + + def __len__(self): + return self.num_files + + def __getitem__(self, t): + pc_filename = self.files["input"][t] + scan = np.fromfile(pc_filename, dtype=np.float32) + scan = scan.reshape((-1, 4)) + coordinates = scan[:, 0:3] # get xyz + remissions = scan[:, 3] # get remission + + features = np.ones((coordinates.shape[0], 1)) + + # AUGMENTATION + if self.augment: + # DOWNSAMPLING + selected_idx = np.random.choice( + coordinates.shape[0], self.downsampling, replace=False + ) + coordinates = coordinates[selected_idx] + features = features[selected_idx] + + # TRANSFORMATIONS + voxel_mtx, affine_mtx = self.voxelizer.get_transformation_matrix() + + rigid_transformation = affine_mtx @ voxel_mtx + # Apply transformations + + homo_coords = np.hstack( + ( + coordinates, + np.ones((coordinates.shape[0], 1), dtype=coordinates.dtype), + ) + ) + # coords = np.floor(homo_coords @ rigid_transformation.T[:, :3]) + coordinates = homo_coords @ rigid_transformation.T[:, :3] + else: + selected_idx = np.arange(coordinates.shape[0]) + + if "label" in self.files.keys(): + label_filename = self.files["label"][t] + labels = np.fromfile(label_filename, dtype=np.int32) + labels = labels.reshape((-1)) + labels = labels & 0xFFFF + if self.augment: + labels = labels[selected_idx] + labels = self.learning_map_function(labels) + else: + labels = np.negative(np.ones(coordinates.shape[0])) + + # REMOVE UNLABELED POINTS IF NOT IN TESTING + if "label" in self.files.keys(): + labelled_idx = labels != -1 + coordinates = coordinates[labelled_idx] + features = features[labelled_idx] + labels = labels[labelled_idx] + selected_idx = selected_idx[labelled_idx] + + discrete_coords, unique_map = ME.utils.sparse_quantize( + coordinates=coordinates, + return_index=True, + quantization_size=self.voxel_size, + ) + + unique_feats = features[unique_map] + unique_labels = labels[unique_map] + selected_idx = selected_idx[unique_map] + + return discrete_coords, unique_feats, unique_labels, selected_idx, t + + +class SemanticPOSSRestrictedDataset(torch.utils.data.Dataset): + def __init__( + self, + config_file="config/semposs_dataset.yaml", + split="train", + voxel_size=0.05, + downsampling=80000, + augment=False, + label_mapping=None, + ): + """Load data from given dataset directory.""" + + with open(config_file, "r") as f: + self.config = yaml.safe_load(f) + + self.files = {"input": []} + if split != "test": + self.files["label"] = [] + self.filenames = [] + + self.voxel_size = voxel_size + self.downsampling = downsampling + self.augment = False + if split == "train" and augment: + self.augment = True + self.clip_bounds = None + self.scale_augmentation_bound = (0.95, 1.05) + self.rotation_augmentation_bound = ( + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + (-np.pi / 20, np.pi / 20), + ) + self.translation_augmentation_ratio_bound = None + + self.voxelizer = Voxelizer( + voxel_size=self.voxel_size, + clip_bound=self.clip_bounds, + use_augmentation=self.augment, + scale_augmentation_bound=self.scale_augmentation_bound, + rotation_augmentation_bound=self.rotation_augmentation_bound, + translation_augmentation_ratio_bound=self.translation_augmentation_ratio_bound, + ignore_label=-1, + ) + + for sequence in self.config["split_sequence"][split]: + for idx, type in enumerate(self.files.keys()): + files_path = os.path.join( + self.config["dataset_path"], + "sequences", + sequence, + self.config["folder_name"][type], + ) + if not os.path.exists(files_path): + raise RuntimeError("Point cloud directory missing: " + files_path) + files = os.listdir(files_path) + data = sorted([os.path.join(files_path, f) for f in files]) + if len(data) == 0: + raise RuntimeError("Missing data for " + type) + self.files[type].extend(data) + if idx == 0: + self.filenames.extend(data) + + learning_map = self.config["learning_map"] + self.learning_map_function = np.vectorize(lambda x: learning_map[x]) + + self.num_files = len(self.filenames) + + if label_mapping is not None: + self.label_mapping_function = np.vectorize(lambda x: label_mapping[x]) + else: + self.label_mapping_function = None + + def __len__(self): + return self.num_files + + def __getitem__(self, t): + pc_filename = self.files["input"][t] + scan = np.fromfile(pc_filename, dtype=np.float32) + scan = scan.reshape((-1, 4)) + coordinates = scan[:, 0:3] # get xyz + remissions = scan[:, 3] # get remission + + features = np.ones((coordinates.shape[0], 1)) + + # AUGMENTATION + if self.augment: + # DOWNSAMPLING + selected_idx = np.random.choice( + coordinates.shape[0], self.downsampling, replace=False + ) + selected_idx = np.sort(selected_idx) + coordinates = coordinates[selected_idx] + features = features[selected_idx] + + # TRANSFORMATIONS + voxel_mtx, affine_mtx = self.voxelizer.get_transformation_matrix() + + rigid_transformation = affine_mtx @ voxel_mtx + # Apply transformations + + homo_coords = np.hstack( + ( + coordinates, + np.ones((coordinates.shape[0], 1), dtype=coordinates.dtype), + ) + ) + # coords = np.floor(homo_coords @ rigid_transformation.T[:, :3]) + coordinates = homo_coords @ rigid_transformation.T[:, :3] + else: + selected_idx = np.arange(coordinates.shape[0]) + + if "label" in self.files.keys(): + label_filename = self.files["label"][t] + labels = np.fromfile(label_filename, dtype=np.int32) + labels = labels.reshape((-1)) + labels = labels & 0xFFFF + if self.augment: + labels = labels[selected_idx] + labels = self.learning_map_function(labels) + else: + labels = np.negative(np.ones(coordinates.shape[0])) + + # REMOVE UNLABELED POINTS IF NOT IN TESTING + if "label" in self.files.keys(): + labelled_idx = labels != -1 + coordinates = coordinates[labelled_idx] + features = features[labelled_idx] + labels = labels[labelled_idx] + selected_idx = selected_idx[labelled_idx] + if self.label_mapping_function is not None: + mapped_labels = self.label_mapping_function(labels) + else: + mapped_labels = np.copy(labels) + + discrete_coords, unique_map = ME.utils.sparse_quantize( + coordinates=coordinates, + return_index=True, + quantization_size=self.voxel_size, + ) + + unique_feats = features[unique_map] + unique_labels = labels[unique_map] + unique_mapped_labels = mapped_labels[unique_map] + selected_idx = selected_idx[unique_map] + + return ( + discrete_coords, + unique_feats, + unique_labels, + selected_idx, + unique_mapped_labels, + t, + ) + + +class SemanticPOSSRestrictedDatasetCleanSplit(SemanticPOSSRestrictedDataset): + def __init__( + self, + clean_mask, + config_file="config/semposs_dataset.yaml", + split="train", + voxel_size=0.05, + downsampling=80000, + augment=False, + label_mapping=None, + ): + super().__init__( + config_file, split, voxel_size, downsampling, augment, label_mapping + ) + self.filenames = np.array(self.filenames)[clean_mask] + self.num_files = len(self.filenames) + for key in self.files.keys(): + self.files[key] = np.array(self.files[key])[clean_mask] + + +class dataset_wrapper(torch.utils.data.Dataset): + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, t): + to_ret = self.dataset.__getitem__(t)[:-1] + self.dataset.__getitem__(t) + + return to_ret diff --git a/utils/scheduler.py b/utils/scheduler.py new file mode 100644 index 0000000..c6d93d0 --- /dev/null +++ b/utils/scheduler.py @@ -0,0 +1,119 @@ +''' +https://github.com/Lightning-AI/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py +''' + +import math +import warnings +from typing import List +from torch.optim import Optimizer + +from torch.optim.lr_scheduler import _LRScheduler + +class LinearWarmupCosineAnnealingLR(_LRScheduler): + """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr + and base_lr followed by a cosine annealing schedule between base_lr and eta_min. + .. warning:: + It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` + after each iteration as calling it after each epoch will keep the starting lr at + warmup_start_lr for the first epoch which is 0 in most cases. + .. warning:: + passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. + It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of + :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing + epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling + train and validation methods. + Example: + >>> layer = nn.Linear(10, 1) + >>> optimizer = Adam(layer.parameters(), lr=0.02) + >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) + >>> # + >>> # the default case + >>> for epoch in range(40): + ... # train(...) + ... # validate(...) + ... scheduler.step() + >>> # + >>> # passing epoch param case + >>> for epoch in range(40): + ... scheduler.step(epoch) + ... # train(...) + ... # validate(...) + """ + + def __init__( + self, + optimizer: Optimizer, + warmup_epochs: int, + max_epochs: int, + warmup_start_lr: float = 0.0, + eta_min: float = 0.0, + last_epoch: int = -1, + ) -> None: + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_epochs (int): Maximum number of iterations for linear warmup + max_epochs (int): Maximum number of iterations + warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + """ + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.warmup_start_lr = warmup_start_lr + self.eta_min = eta_min + + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + """Compute learning rate using chainable form of the scheduler.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", + UserWarning, + ) + + if self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + if self.last_epoch < self.warmup_epochs: + return [ + group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + if self.last_epoch == self.warmup_epochs: + return self.base_lrs + if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + return [ + (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + / ( + 1 + + math.cos( + math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) + ) + ) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self) -> List[float]: + """Called when epoch is passed as a param to the `step` function of the scheduler.""" + if self.last_epoch < self.warmup_epochs: + return [ + self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr in self.base_lrs + ] + + return [ + self.eta_min + + 0.5 + * (base_lr - self.eta_min) + * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + for base_lr in self.base_lrs + ] \ No newline at end of file diff --git a/utils/sinkhorn_knopp.py b/utils/sinkhorn_knopp.py new file mode 100644 index 0000000..7ac3a4f --- /dev/null +++ b/utils/sinkhorn_knopp.py @@ -0,0 +1,44 @@ +import torch + + +class SinkhornKnopp(torch.nn.Module): + def __init__(self, num_iters=3, epsilon=0.05): + super().__init__() + self.num_iters = num_iters + self.epsilon = epsilon + + @torch.no_grad() + def forward(self, features, head, queue=None): + if queue is None or queue.shape[0] == 0: + queue = None + if queue is not None: + features = torch.vstack((features, queue)) + + features = torch.nn.functional.normalize(features, dim=1, p=2) + head = torch.nn.functional.normalize(head, dim=1, p=2) + + logits = features@head + + logits = logits.to(torch.float64) + Q = torch.exp(logits / self.epsilon).t() + B = Q.shape[1] + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + Q /= sum_Q + + for it in range(self.num_iters): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the colomns must sum to 1 so that Q is an assignment + to_ret = Q.t() if queue is None else Q.t()[:-queue.shape[0]] + + return to_ret \ No newline at end of file diff --git a/utils/unkn_labels.py b/utils/unkn_labels.py new file mode 100644 index 0000000..9897144 --- /dev/null +++ b/utils/unkn_labels.py @@ -0,0 +1,40 @@ +MAX_SPLIT_NUM = 4 + +def unknown_labels(split, dataset_config): + '''Creates the set of unknown labels in a way that all the classes in the same splits have roughly the same number of points''' + class_percentages = dataset_config['content'] + label_percentages = { + i: 0.0 for i in dataset_config['learning_map_inv'].keys()} + for label, percentage in class_percentages.items(): + mapped_label = dataset_config['learning_map'][label] + label_percentages[mapped_label] += percentage + del(label_percentages[-1]) + label_percentages = sorted( + label_percentages, key=label_percentages.get, reverse=True) + novel_classes_per_split = int(len(label_percentages)/MAX_SPLIT_NUM) + act_splitting = [novel_classes_per_split for _ in range(MAX_SPLIT_NUM)] + tot_num = sum(act_splitting) + i = 0 + while tot_num != len(label_percentages): + act_splitting[i] += 1 + i += 1 + tot_num = sum(act_splitting) + start = sum(act_splitting[:split]) + end = start + act_splitting[split] + return label_percentages[start:end] + + +def label_mapping(unknown_labels, all_labels): + new_label = -1 + label_mapping = {} + label_mapping_inv = {} + for label in all_labels: + if label not in unknown_labels: + label_mapping[label] = new_label + label_mapping_inv[new_label] = label + new_label += 1 + label_mapping = {**label_mapping, ** + {unk: new_label for unk in unknown_labels}} + del(label_mapping[-1]) + del(label_mapping_inv[-1]) + return label_mapping, label_mapping_inv, new_label \ No newline at end of file diff --git a/utils/voxelizer.py b/utils/voxelizer.py new file mode 100644 index 0000000..93d85f7 --- /dev/null +++ b/utils/voxelizer.py @@ -0,0 +1,74 @@ +import numpy as np +from scipy import linalg +import collections + +def M(axis, theta): + return linalg.expm(np.cross(np.eye(3), axis/linalg.norm(axis)*theta)) + +class Voxelizer: + + def __init__(self, + voxel_size=0.05, + clip_bound=None, + use_augmentation=False, + scale_augmentation_bound=None, + rotation_augmentation_bound=None, + translation_augmentation_ratio_bound=None, + ignore_label=255): + """ + Args: + voxel_size: side length of a voxel + clip_bound: boundary of the voxelizer. Points outside the bound will be deleted + expects either None or an array like ((-100, 100), (-100, 100), (-100, 100)). + scale_augmentation_bound: None or (0.9, 1.1) + rotation_augmentation_bound: None or ((np.pi / 6, np.pi / 6), None, None) for 3 axis. + Use random order of x, y, z to prevent bias. + translation_augmentation_bound: ((-5, 5), (0, 0), (-10, 10)) + ignore_label: label assigned for ignore (not a training label). + """ + self.voxel_size = voxel_size + self.clip_bound = clip_bound + if ignore_label is not None: + self.ignore_label = ignore_label + else: + self.ignore_label = -100 + # Augmentation + self.use_augmentation = use_augmentation + self.scale_augmentation_bound = scale_augmentation_bound + self.rotation_augmentation_bound = rotation_augmentation_bound + self.translation_augmentation_ratio_bound = translation_augmentation_ratio_bound + + def get_transformation_matrix(self): + voxelization_matrix, rotation_matrix = np.eye(4), np.eye(4) + + # Transform pointcloud coordinate to voxel coordinate. + # 1. Random rotation + rot_mat = np.eye(3) + if self.use_augmentation and self.rotation_augmentation_bound is not None: + if isinstance(self.rotation_augmentation_bound, collections.Iterable): + rot_mats = [] + for axis_ind, rot_bound in enumerate(self.rotation_augmentation_bound): + theta = 0 + axis = np.zeros(3) + axis[axis_ind] = 1 + if rot_bound is not None: + theta = np.random.uniform(*rot_bound) + rot_mats.append(M(axis, theta)) + # Use random order + np.random.shuffle(rot_mats) + rot_mat = rot_mats[0] @ rot_mats[1] @ rot_mats[2] + else: + raise ValueError() + rotation_matrix[:3, :3] = rot_mat + # 2. Scale and translate to the voxel space. + scale = 1 + if self.use_augmentation and self.scale_augmentation_bound is not None: + scale *= np.random.uniform(*self.scale_augmentation_bound) + np.fill_diagonal(voxelization_matrix[:3, :3], scale) + + # 3. Translate + if self.use_augmentation and self.translation_augmentation_ratio_bound is not None: + tr = [np.random.uniform(*t) for t in self.translation_augmentation_ratio_bound] + rotation_matrix[:3, 3] = tr + # Get final transformation matrix. + return voxelization_matrix, rotation_matrix \ No newline at end of file