diff --git a/src/cultionet/data/datasets.py b/src/cultionet/data/datasets.py index 1a99ab28..805e12d3 100644 --- a/src/cultionet/data/datasets.py +++ b/src/cultionet/data/datasets.py @@ -20,7 +20,7 @@ from ..augment.augmenters import Augmenters from ..errors import TensorShapeError from ..utils.logging import set_color_logger -from ..utils.model_preprocessing import TqdmParallel +from ..utils.model_preprocessing import ParallelProgress from ..utils.normalize import NormValues from .data import Data from .spatial_dataset import SpatialDataset @@ -257,7 +257,7 @@ def check_dims( backend="loky", n_jobs=self.processes, ): - with TqdmParallel( + with ParallelProgress( tqdm_kwargs={ "total": len(self), "desc": "Checking dimensions", diff --git a/src/cultionet/data/lookup.py b/src/cultionet/data/lookup.py deleted file mode 100644 index 641ae371..00000000 --- a/src/cultionet/data/lookup.py +++ /dev/null @@ -1,470 +0,0 @@ -NON_AG = frozenset(("unknown", "developed", "trees")) -NON_CROP = frozenset(("hay", "pasture")) - -CDL_COLORS = dict( - background="#ffffff", - all_crops="#E4A520", - maize="#ffd300", - spring_maize="#dca50c", - maize1="#dca50c", - maize2="#ffd300", - dbl_maize="#b29300", - dbl_spring_maize="#dca50c", - dbl_maize1="#b29300", - dbl_maize2="#b29300", - dbl_cotton="#fe2725", - soybeans="#267000", - spring_soybeans="#267000", - soybeans1="#359c00", - soybeans2="#267000", - dbl_soybeans="#1a4f00", - dbl_spring_soybeans="#1a4f00", - dbl_soybeans1="#1a4f00", - dbl_soybeans2="#1a4f00", - cotton="#fe2725", - peanuts="#70a500", - dbl_wheat="#a57000", - dbl_winter_wheat_soy="#707002", - dbl_winter_wheat_maize="#ffd301", - millet="#70004a", - spring_millet="#8f005f", - pecans="#b4705b", - rye="#ac017c", - oats="#a15989", - dbl_cropping="#9a622a", - sorghum="#fe9e0c", - spring_sorghum="#eb9109", - dbl_sorghum="#905906", - winter_wheat="#a57000", - spring_wheat="#d9b56b", - durum_wheat="#896454", - dry_beans="#a40000", - safflower="#d6d700", - dbl_safflower="#d6d700", - rape_seed="#d1ff00", - mustard="#00b04a", - buckwheat="#d69dbd", - sudangrass="#b663a0", - dbl_sudangrass="#8d5060", - onions="#ff6966", - camelina="#02ad4c", - peas="#54ff00", - watermelons="#ff6766", - honeydew_melons="#ff6766", - dbl_soy_oats="#267000", - dbl_maize_soy="#ffd300", - sweet_potatoes="#702601", - hops="#00ae4c", - pumpkins="#ff6766", - dbl_durum_wheat_sorghum="#ff9e0a", - dbl_barley_sorghum="#ff9e0a", - triticale="#d69dbd", - pop_orn_maize="#dca50c", - almonds="#00a682", - pistachios="#00ff8c", - aquaculture="#01ffff", - dbl_winter_wheat_cotton="#a57000", - dbl_soy_cotton="#267000", - sweet_maize="#dca50c", - sunflower="#ffff00", - dbl_sunflower="#ffff00", - flaxseed="#8099fe", - clover="#e8c0ff", - sod_grass="#afffdc", - lentils="#00deaf", - sugarbeets="#a700e4", - walnuts="#ead6af", - dbl_oats_maize="#ffd300", - herbs="#7fd3ff", - blueberries="#000098", - peaches="#ff8daa", - pears="#b29b71", - grapes="#6f4489", - orchard="#6f4489", - cucumbers="#fd6666", - chick_peaks="#00b04a", - misc_fruits_vegs="#ff6766", - carrots="#ff6666", - asparagus="#ff6666", - garlic="#ff6666", - cantaloupes="#ff6666", - prunes="#ff8fab", - olives="#334a33", - oranges="#e37026", - broccoli="#ff6666", - cabbage="#ff6666", - cauliflower="#ff6666", - celery="#ff6666", - peppers="#ff6766", - pomegranates="#b09970", - nectarines="#ff8fab", - greens="#ff6666", - plums="#ff8fab", - strawberries="#ff6666", - rice="#01a8e6", - potatoes="#702601", - alfalfa="#df91c7", - other_crops="#00ae4c", - sugarcane="#648d6c", - dbl_sugarcane="#648d6c", 
- speltz="#d69dbd", - winter_barley="#e2007d", - barley="#e240a4", - dbl_barley="#e2007d", - dbl_winter_wheat_sorghum="#a57001", - dbl_barley_maize="#ffd300", - dbl_barley_soy="#267000", - canola="#d1ff00", - switchgrass="#00ae4c", - tomatoes="#f3a378", - tobacco="#008539", - pastureland="#e9ffbf", - grassland_pasture="#e9ffbf", - savanna="#739f73", - other_hay="#a5f18c", - dbl_hay="#7fef81", - fallow="#bfbf77", - harvested="#bfbf77", - planted="#9bbf77", - cherries="#ff00ff", - apples="#bb004f", - squash="#ff6766", - apricots="#ff8fab", - vetch="#00b04a", - lettuce="#ff6666", - turnips="#ff6766", - eggplants="#ff6766", - radishes="#ff6766", - gourds="#ff6666", - cranberries="#ff6666", - christmas_trees="#007776", - other_tree_crops="#b29b71", - citrus="#ffff7d", - deciduous_forest="#92cc92", - evergreen_forest="#92ccaf", - mixed_forest="#afcc92", - forest="#4E6507", - deforestation="#ff35e4", - shrubland="#c7d5a0", - cactus="#c7d5a0", - woody_wetlands="#7fb39a", - herbaceous_wetlands="#7fb2b3", - wetlands="#7fb2b3", - barren="#cdbfa4", - plantation="#7833ad", - open_water="#5990B1", - water="#4c70a4", - developed="#707A88", - developed_high="#5f0100", -) - -CDL_LABELS = dict( - background=0, - cropland=1, - maize=1, - cotton=2, - rice=3, - sorghum=4, - soybeans=5, - sunflower=6, - spring_maize=7, - spring_soybeans=8, - spring_sorghum=9, - peanuts=10, - tobacco=11, - sweet_maize=12, - pop_orn_maize=13, - mint=14, - maize1=15, - maize2=16, - soybeans1=17, - soybeans2=18, - spring_millet=19, - barley=20, - winter_barley=21, - durum_wheat=22, - spring_wheat=23, - winter_wheat=24, - other_small_grains=25, - dbl_winter_wheat_soy=26, - rye=27, - oats=28, - millet=29, - speltz=30, - canola=31, - flaxseed=32, - safflower=33, - rape_seed=34, - mustard=35, - alfalfa=36, - other_hay=37, - camelina=38, - buckwheat=39, - sudangrass=40, - sugarbeets=41, - dry_beans=42, - potatoes=43, - other_crops=44, - sugarcane=45, - sweet_potatoes=46, - misc_fruits_vegs=47, - watermelons=48, - onions=49, - cucumbers=50, - chick_peaks=51, - lentils=52, - peas=53, - tomatoes=54, - hops=56, - herbs=57, - clover=58, - sod_grass=59, - switchgrass=60, - fallow=61, - harvested=62, - planted=63, - young=64, - cherries=66, - peaches=67, - apples=68, - grapes=69, - christmas_trees=70, - other_tree_crops=71, - citrus=72, - pecans=74, - almonds=75, - walnuts=76, - pears=77, - orchard=78, - tilled=79, - dbl_maize=80, - dbl_cotton=81, - dbl_sorghum=82, - dbl_soybeans=83, - dbl_sunflower=84, - dbl_tobacco=85, - dbl_millet=86, - dbl_hay=87, - dbl_sudangrass=88, - dbl_dry_beans=89, - dbl_other_crops=90, - dbl_sugarcane=91, - aquaculture=92, - dbl_onions=93, - dbl_rice=94, - dbl_alfalfa=95, - dbl_clover=96, - dbl_wheat=97, - dbl_barley=98, - dbl_oats=99, - dbl_spring_maize=100, - dbl_spring_soybeans=101, - dbl_safflower=104, - dbl_cropping=110, - open_water=111, - developed_open=121, - developed_low=122, - developed_medium=123, - developed_high=124, - barren=131, - plantation=138, - eucalyptus_plantation=139, - pine_plantation=140, - deciduous_forest=141, - evergreen_forest=142, - mixed_forest=143, - forest=144, - deforestation=145, - reforestation=146, - shrubland=152, - cactus=153, - savanna=173, - grassland=174, - pastureland=175, - grassland_pasture=176, - woody_wetlands=190, - herbaceous_wetlands=195, - pistachios=204, - triticale=205, - carrots=206, - asparagus=207, - garlic=208, - cantaloupes=209, - prunes=210, - olives=211, - oranges=212, - honeydew_melons=213, - broccoli=214, - peppers=216, - pomegranates=217, - 
nectarines=218, - greens=219, - plums=220, - strawberries=221, - squash=222, - apricots=223, - vetch=224, - dbl_winter_wheat_maize=225, - dbl_oats_maize=226, - lettuce=227, - pumpkins=229, - dbl_durum_wheat_sorghum=234, - dbl_barley_sorghum=235, - dbl_winter_wheat_sorghum=236, - dbl_barley_maize=237, - dbl_winter_wheat_cotton=238, - dbl_soy_cotton=239, - dbl_soy_oats=240, - dbl_maize_soy=241, - blueberries=242, - cabbage=243, - cauliflower=244, - celery=245, - radishes=246, - turnips=247, - eggplants=248, - gourds=249, - cranberries=250, - dbl_barley_soy=254, -) - -CDL_CROP_LABELS = dict( - maize=1, - cotton=2, - rice=3, - sorghum=4, - soybeans=5, - sunflower=6, - spring_maize=7, - spring_soybeans=8, - spring_sorghum=9, - peanuts=10, - tobacco=11, - sweet_maize=12, - pop_orn_maize=13, - mint=14, - maize1=15, - maize2=16, - soybeans1=17, - soybeans2=18, - spring_millet=19, - barley=20, - winter_barley=21, - durum_wheat=22, - spring_wheat=23, - winter_wheat=24, - other_small_grains=25, - dbl_winter_wheat_soy=26, - rye=27, - oats=28, - millet=29, - speltz=30, - canola=31, - flaxseed=32, - safflower=33, - rape_seed=34, - mustard=35, - alfalfa=36, - other_hay=37, - camelina=38, - buckwheat=39, - sudangrass=40, - sugarbeets=41, - dry_beans=42, - potatoes=43, - other_crops=44, - sugarcane=45, - sweet_potatoes=46, - misc_fruits_vegs=47, - watermelons=48, - onions=49, - cucumbers=50, - chick_peaks=51, - lentils=52, - peas=53, - tomatoes=54, - hops=56, - herbs=57, - clover=58, - sod_grass=59, - switchgrass=60, - fallow=61, - cherries=66, - peaches=67, - apples=68, - grapes=69, - citrus=72, - pecans=74, - almonds=75, - walnuts=76, - pears=77, - dbl_maize=80, - dbl_cotton=81, - dbl_sorghum=82, - dbl_soybeans=83, - dbl_sunflower=84, - dbl_tobacco=85, - dbl_millet=86, - dbl_hay=87, - dbl_sudangrass=88, - dbl_dry_beans=89, - dbl_other_crops=90, - dbl_sugarcane=91, - aquaculture=92, - dbl_onions=93, - dbl_rice=94, - dbl_alfalfa=95, - dbl_clover=96, - dbl_wheat=97, - dbl_barley=98, - dbl_oats=99, - dbl_spring_maize=100, - dbl_spring_soybeans=101, - dbl_safflower=104, - dbl_cropping=110, - pistachios=204, - triticale=205, - carrots=206, - asparagus=207, - garlic=208, - cantaloupes=209, - prunes=210, - olives=211, - oranges=212, - honeydew_melons=213, - broccoli=214, - peppers=216, - pomegranates=217, - nectarines=218, - greens=219, - plums=220, - strawberries=221, - squash=222, - apricots=223, - vetch=224, - dbl_winter_wheat_maize=225, - dbl_oats_maize=226, - lettuce=227, - pumpkins=229, - dbl_durum_wheat_sorghum=234, - dbl_barley_sorghum=235, - dbl_winter_wheat_sorghum=236, - dbl_barley_maize=237, - dbl_winter_wheat_cotton=238, - dbl_soy_cotton=239, - dbl_soy_oats=240, - dbl_maize_soy=241, - blueberries=242, - cabbage=243, - cauliflower=244, - celery=245, - radishes=246, - turnips=247, - eggplants=248, - gourds=249, - cranberries=250, - dbl_barley_soy=254, -) - -CDL_LABELS_r = {v: k for k, v in CDL_LABELS.items()} -CDL_CROP_LABELS_r = {v: k for k, v in CDL_CROP_LABELS.items()} diff --git a/src/cultionet/data/spatial_dataset.py b/src/cultionet/data/spatial_dataset.py index 25a4cf33..c23c9c1e 100644 --- a/src/cultionet/data/spatial_dataset.py +++ b/src/cultionet/data/spatial_dataset.py @@ -7,7 +7,7 @@ from shapely.geometry import box from torch.utils.data import Dataset -from ..utils.model_preprocessing import TqdmParallel +from ..utils.model_preprocessing import ParallelProgress def get_box_id(data_id: str, *bounds) -> tuple: @@ -34,7 +34,7 @@ def to_frame(self, id_column: str, n_jobs: int) -> 
gpd.GeoDataFrame: """Converts the Dataset to a GeoDataFrame.""" with parallel_backend(backend="loky", n_jobs=n_jobs): - with TqdmParallel( + with ParallelProgress( tqdm_kwargs={ "total": len(self), "desc": "Building GeoDataFrame", diff --git a/src/cultionet/enums/__init__.py b/src/cultionet/enums/__init__.py index 9806c9d3..666246c1 100644 --- a/src/cultionet/enums/__init__.py +++ b/src/cultionet/enums/__init__.py @@ -65,12 +65,7 @@ class ModelNames(StrEnum): class ModelTypes(StrEnum): - UNET = 'unet' - RESUNET = 'resunet' - UNET3PSI = 'UNet3Psi' - RESUNET3PSI = 'ResUNet3Psi' TOWERUNET = 'TowerUNet' - TRESAUNET = 'TemporalResAUNet' class ResBlockTypes(StrEnum): diff --git a/src/cultionet/layers/encodings.py b/src/cultionet/layers/encodings.py index 80f05d1b..29576d1f 100644 --- a/src/cultionet/layers/encodings.py +++ b/src/cultionet/layers/encodings.py @@ -28,18 +28,3 @@ def get_sinusoid_encoding_table( sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 return torch.tensor(sinusoid_table, dtype=torch.float32) - - -def cartesian(lon: torch.Tensor, lat: torch.Tensor) -> torch.Tensor: - """ - Source: - https://github.com/nasaharvest/presto/blob/main/presto/presto.py - """ - with torch.no_grad(): - lon_rad = torch.deg2rad(lon) - lat_rad = torch.deg2rad(lat) - x = torch.cos(lat_rad) * torch.cos(lon_rad) - y = torch.cos(lat_rad) * torch.sin(lon_rad) - z = torch.sin(lat_rad) - - return torch.stack([x, y, z], dim=-1) diff --git a/src/cultionet/losses/losses.py b/src/cultionet/losses/losses.py index 9c747d2c..4216dd7a 100644 --- a/src/cultionet/losses/losses.py +++ b/src/cultionet/losses/losses.py @@ -2,11 +2,9 @@ import warnings import einops -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import torchmetrics from kornia.contrib import distance_transform try: @@ -14,40 +12,6 @@ except ImportError: topnn = None -from ..data.data import Data - - -class FieldOfJunctionsLoss(nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, - patches: torch.Tensor, - image_patches: torch.Tensor, - ) -> torch.Tensor: - """Compute the objective of our model (see Equation 8 of the paper).""" - - # Compute negative log-likelihood for each patch (shape [N, H', W']) - loss_per_patch = einops.reduce( - ( - einops.rearrange(image_patches, 'b c p k h w -> b 1 c p k h w') - - patches - ) - ** 2, - 'b n c p k h w -> b n c h w', - 'mean', - ) - loss_per_patch = einops.reduce( - loss_per_patch, 'b n c h w -> b n h w', 'sum' - ) - # Reduce to the batch mean - loss_per_patch = einops.reduce( - loss_per_patch, 'b n h w -> n h w', 'mean' - ) - - return loss_per_patch.mean() - class LossPreprocessing(nn.Module): def __init__( @@ -391,108 +355,6 @@ def forward( return self.loss_func(inputs, targets) -class FocalLoss(nn.Module): - """Focal loss. 
- - Reference: - https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook - """ - - def __init__( - self, - alpha: float = 0.8, - gamma: float = 2.0, - weight: T.Optional[torch.Tensor] = None, - label_smoothing: T.Optional[float] = 0.1, - ): - super().__init__() - - self.alpha = alpha - self.gamma = gamma - - self.preprocessor = LossPreprocessing( - inputs_are_logits=True, apply_transform=True - ) - self.cross_entropy_loss = nn.CrossEntropyLoss( - weight=weight, reduction="none", label_smoothing=label_smoothing - ) - - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor - ) -> torch.Tensor: - inputs, targets = self.preprocessor(inputs, targets) - ce_loss = self.cross_entropy_loss(inputs, targets.half()) - ce_exp = torch.exp(-ce_loss) - focal_loss = self.alpha * (1.0 - ce_exp) ** self.gamma * ce_loss - - return focal_loss.mean() - - -class QuantileLoss(nn.Module): - """Loss function for quantile regression. - - Reference: - https://pytorch-forecasting.readthedocs.io/en/latest/_modules/pytorch_forecasting/metrics.html#QuantileLoss - - THE MIT License - - Copyright 2020 Jan Beitner - """ - - def __init__(self, quantiles: T.Tuple[float, float, float]): - super().__init__() - - self.quantiles = quantiles - - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor - ) -> torch.Tensor: - """Performs a single forward pass. - - Args: - inputs: Predictions from model (probabilities, logits or labels). - targets: Ground truth values. - - Returns: - Quantile loss (float) - """ - losses = [] - for i, q in enumerate(self.quantiles): - errors = targets - inputs[:, i] - losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(1)) - loss = torch.cat(losses, dim=1).sum(dim=1).mean() - - return loss - - -class WeightedL1Loss(nn.Module): - """Weighted L1Loss loss.""" - - def __init__(self): - super().__init__() - - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor - ) -> torch.Tensor: - """Performs a single forward pass. - - Args: - inputs: Predictions from model. - targets: Ground truth values. - - Returns: - Loss (float) - """ - inputs = inputs.contiguous().view(-1) - targets = targets.contiguous().view(-1) - - mae = torch.abs(inputs - targets) - weight = inputs + targets - loss = (mae * weight).mean() - - return loss - - class MSELoss(nn.Module): """MSE loss.""" @@ -576,50 +438,6 @@ def forward( return torch.einsum("bhw, bhw -> bhw", distances, 1.0 - probs).mean() -class MultiScaleSSIMLoss(nn.Module): - """Multi-scale Structural Similarity Index Measure loss.""" - - def __init__(self): - super().__init__() - - self.msssim = torchmetrics.MultiScaleStructuralSimilarityIndexMeasure( - gaussian_kernel=False, - kernel_size=3, - data_range=1.0, - k1=1e-4, - k2=9e-4, - ) - - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor, data: Data - ) -> torch.Tensor: - """Performs a single forward pass. - - Args: - inputs: Predicted probabilities. - targets: Ground truth inverse distance transform, where distances - along edges are 1. - data: Data object used to extract dimensions. 
- - Returns: - Loss (float) - """ - height = ( - int(data.height) if data.batch is None else int(data.height[0]) - ) - width = int(data.width) if data.batch is None else int(data.width[0]) - batch_size = 1 if data.batch is None else data.batch.unique().size(0) - - inputs = self.gc(inputs.unsqueeze(1), batch_size, height, width) - targets = self.gc(targets.unsqueeze(1), batch_size, height, width).to( - dtype=inputs.dtype - ) - - loss = 1.0 - self.msssim(inputs, targets) - - return loss - - class TopologyLoss(nn.Module): def __init__(self): super().__init__() diff --git a/src/cultionet/losses/topological.py b/src/cultionet/losses/topological.py deleted file mode 100644 index 4ebe150b..00000000 --- a/src/cultionet/losses/topological.py +++ /dev/null @@ -1,304 +0,0 @@ -import typing as T - -import numpy as np -import torch -import gudhi - - -def critical_points( - x: torch.Tensor, -) -> T.Tuple[T.List[np.ndarray], T.List[np.ndarray], T.List[np.ndarray], bool]: - batch_size = x.shape[0] - lh_vector = 1.0 - x.flatten() - cubical_complex = gudhi.CubicalComplex( - dimensions=x.shape, top_dimensional_cells=lh_vector - ) - cubical_complex.persistence(homology_coeff_field=2, min_persistence=0) - cofaces = cubical_complex.cofaces_of_persistence_pairs() - cofaces_batch_size = len(cofaces[0]) - - if (cofaces_batch_size == 0) or (cofaces_batch_size != batch_size): - return None, None, None, False - - pd_lh = [ - np.c_[ - lh_vector[cofaces[0][batch][:, 0]], - lh_vector[cofaces[0][batch][:, 1]], - ] - for batch in range(0, batch_size) - ] - bcp_lh = [ - np.c_[ - cofaces[0][batch][:, 0] // x.shape[-1], - cofaces[0][batch][:, 0] % x.shape[-1], - ] - for batch in range(0, batch_size) - ] - dcp_lh = [ - np.c_[ - cofaces[0][batch][:, 1] // x.shape[-1], - cofaces[0][batch][:, 1] % x.shape[-1], - ] - for batch in range(0, batch_size) - ] - - return pd_lh, bcp_lh, dcp_lh, True - - -def compute_dgm_force( - lh_dgm: np.ndarray, - gt_dgm: np.ndarray, - pers_thresh: float = 0.03, - pers_thresh_perfect: float = 0.99, - do_return_perfect: bool = False, -) -> T.Tuple[np.ndarray, np.ndarray]: - """Compute the matching force between the likelihood and ground truth persistence diagrams. - - Args: - lh_dgm: Likelihood persistence diagram. - gt_dgm: Ground truth persistence diagram. - pers_thresh: Persistence threshold (also called the dynamic value), which measures the difference - between a local maximum critical point value and its neighboring minimum critical point value. - Points with persistence smaller than this threshold are filtered out. Default is 0.03. - pers_thresh_perfect: The distance difference between two critical points that can be considered a - correct match. Default is 0.99. - do_return_perfect: Whether to also return the perfectly matched persistence points. Default is ``False``. - - Returns: - force_list: The matching force between the likelihood and ground truth persistence diagrams. - idx_holes_to_fix: Indices of persistence points to fix during subsequent training. - idx_holes_to_remove: Indices of persistence points to remove during subsequent training.
- """ - lh_pers = abs(lh_dgm[:, 1] - lh_dgm[:, 0]) - if gt_dgm.shape[0] == 0: - gt_pers = None - gt_n_holes = 0 - else: - gt_pers = gt_dgm[:, 1] - gt_dgm[:, 0] - gt_n_holes = gt_pers.size # number of holes in gt - - if (gt_pers is None) or (gt_n_holes == 0): - idx_holes_to_fix = np.array([], dtype=int) - idx_holes_to_remove = np.array(list(set(range(lh_pers.size)))) - idx_holes_perfect = [] - else: - # check to ensure that all gt dots have persistence 1 - tmp = gt_pers > pers_thresh_perfect - - # get "perfect holes" - holes which do not need to be fixed, i.e., find top - # lh_n_holes_perfect indices - # check to ensure that at least one dot has persistence 1; it is the hole - # formed by the padded boundary - # if no hole is ~1 (ie >.999) then just take all holes with max values - tmp = lh_pers > pers_thresh_perfect # old: assert tmp.sum() >= 1 - lh_pers_sorted_indices = np.argsort(lh_pers)[::-1] - if np.sum(tmp) >= 1: - lh_n_holes_perfect = tmp.sum() - idx_holes_perfect = lh_pers_sorted_indices[:lh_n_holes_perfect] - else: - idx_holes_perfect = [] - - # find top gt_n_holes indices - idx_holes_to_fix_or_perfect = lh_pers_sorted_indices[:gt_n_holes] - - # the difference is holes to be fixed to perfect - idx_holes_to_fix = np.array( - list(set(idx_holes_to_fix_or_perfect) - set(idx_holes_perfect)) - ) - - # remaining holes are all to be removed - idx_holes_to_remove = lh_pers_sorted_indices[gt_n_holes:] - - # only select the ones whose persistence is large enough - # set a threshold to remove meaningless persistence dots - pers_thd = pers_thresh - idx_valid = np.where(lh_pers > pers_thd)[0] - idx_holes_to_remove = np.array( - list(set(idx_holes_to_remove).intersection(set(idx_valid))) - ) - - force_list = np.zeros(lh_dgm.shape) - - # push each hole-to-fix to (0,1) - if idx_holes_to_fix.shape[0] > 0: - force_list[idx_holes_to_fix, 0] = 0 - lh_dgm[idx_holes_to_fix, 0] - force_list[idx_holes_to_fix, 1] = 1 - lh_dgm[idx_holes_to_fix, 1] - - # push each hole-to-remove to (0,1) - if idx_holes_to_remove.shape[0] > 0: - force_list[idx_holes_to_remove, 0] = lh_pers[ - idx_holes_to_remove - ] / np.sqrt(2.0) - force_list[idx_holes_to_remove, 1] = -lh_pers[ - idx_holes_to_remove - ] / np.sqrt(2.0) - - if do_return_perfect: - return ( - force_list, - idx_holes_to_fix, - idx_holes_to_remove, - idx_holes_perfect, - ) - - return force_list, idx_holes_to_fix, idx_holes_to_remove - - -def adjust_holes_to_fix( - topo_cp_weight_map: np.ndarray, - topo_cp_ref_map: np.ndarray, - topo_mask: np.ndarray, - hole_indices: np.ndarray, - pairs: np.ndarray, - fill_weight: int, - fill_ref: int, - height: int, - width: int, -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: - mask = ( - (pairs[hole_indices][:, 0] >= 0) - * (pairs[hole_indices][:, 0] < height) - * (pairs[hole_indices][:, 1] >= 0) - * (pairs[hole_indices][:, 1] < width) - ) - indices = ( - pairs[hole_indices][:, 0][mask], - pairs[hole_indices][:, 1][mask], - ) - topo_cp_weight_map[indices] = fill_weight - topo_cp_ref_map[indices] = fill_ref - topo_mask[indices] = 1 - - return topo_cp_weight_map, topo_cp_ref_map, topo_mask - - -def adjust_holes_to_remove( - likelihood: np.ndarray, - topo_cp_weight_map: np.ndarray, - topo_cp_ref_map: np.ndarray, - topo_mask: np.ndarray, - hole_indices: np.ndarray, - pairs_b: np.ndarray, - pairs_d: np.ndarray, - fill_weight: int, - fill_ref: int, - height: int, - width: int, -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: - mask = ( - (pairs_b[hole_indices][:, 0] >= 0) - * (pairs_b[hole_indices][:, 0] < height) - * 
(pairs_b[hole_indices][:, 1] >= 0) - * (pairs_b[hole_indices][:, 1] < width) - ) - indices = ( - pairs_b[hole_indices][:, 0][mask], - pairs_b[hole_indices][:, 1][mask], - ) - topo_cp_weight_map[indices] = fill_weight - topo_mask[indices] = 1 - - nested_mask = ( - mask - * (pairs_d[hole_indices][:, 0] >= 0) - * (pairs_d[hole_indices][:, 0] < height) - * (pairs_d[hole_indices][:, 1] >= 0) - * (pairs_d[hole_indices][:, 1] < width) - ) - indices_b = ( - pairs_b[hole_indices][:, 0][nested_mask], - pairs_b[hole_indices][:, 1][nested_mask], - ) - indices_d = ( - pairs_d[hole_indices][:, 0][nested_mask], - pairs_d[hole_indices][:, 1][nested_mask], - ) - topo_cp_ref_map[indices_b] = likelihood[indices_d] - topo_mask[indices_b] = 1 - - indices_inv = ( - pairs_b[hole_indices][:, 0][mask], - pairs_b[hole_indices][:, 1][mask], - ) - topo_cp_ref_map[indices_inv] = fill_ref - topo_mask[indices_inv] = 1 - - return topo_cp_weight_map, topo_cp_ref_map, topo_mask - - -def set_topology_weights( - likelihood: np.ndarray, - topo_cp_weight_map: np.ndarray, - topo_cp_ref_map: np.ndarray, - topo_mask: np.ndarray, - bcp_lh: np.ndarray, - dcp_lh: np.ndarray, - idx_holes_to_fix: np.ndarray, - idx_holes_to_remove: np.ndarray, - height: int, - width: int, -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: - x = 0 - y = 0 - - if len(idx_holes_to_fix) > 0: - topo_cp_weight_map, topo_cp_ref_map, topo_mask = adjust_holes_to_fix( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_fix, - pairs=bcp_lh, - fill_weight=1, - fill_ref=0, - height=height, - width=width, - ) - topo_cp_weight_map, topo_cp_ref_map, topo_mask = adjust_holes_to_fix( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_fix, - pairs=dcp_lh, - fill_weight=1, - fill_ref=1, - height=height, - width=width, - ) - if len(idx_holes_to_remove) > 0: - ( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask, - ) = adjust_holes_to_remove( - likelihood, - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_remove, - pairs_b=bcp_lh, - pairs_d=dcp_lh, - fill_weight=1, - fill_ref=1, - height=height, - width=width, - ) - ( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask, - ) = adjust_holes_to_remove( - likelihood, - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_remove, - pairs_b=dcp_lh, - pairs_d=bcp_lh, - fill_weight=1, - fill_ref=0, - height=height, - width=width, - ) - - return topo_cp_weight_map, topo_cp_ref_map, topo_mask diff --git a/src/cultionet/models/cultionet.py b/src/cultionet/models/cultionet.py index a0418a3e..8a8c9d2d 100644 --- a/src/cultionet/models/cultionet.py +++ b/src/cultionet/models/cultionet.py @@ -6,248 +6,10 @@ from .. 
import nn as cunn from ..data.data import Data from ..enums import AttentionTypes, ModelTypes, ResBlockTypes -from .nunet import ResUNet3Psi, TowerUNet, UNet3Psi +from .nunet import TowerUNet from .temporal_transformer import TemporalTransformer -def scale_min_max( - x: torch.Tensor, - min_in: float, - max_in: float, - min_out: float, - max_out: float, -) -> torch.Tensor: - return (((max_out - min_out) * (x - min_in)) / (max_in - min_in)) + min_out - - -class GeoRefinement(nn.Module): - def __init__( - self, - in_features: int, - in_channels: int = 21, - n_hidden: int = 32, - out_channels: int = 2, - ): - super(GeoRefinement, self).__init__() - - # in_channels = - # StarRNN 3 + 2 - # Distance transform x4 - # Edge sigmoid x4 - # Crop softmax x4 - - self.gamma = nn.Parameter(torch.ones((1, out_channels, 1, 1))) - self.geo_attention = nn.Sequential( - cunn.ConvBlock2d( - in_channels=2, - out_channels=out_channels, - kernel_size=1, - padding=0, - add_activation=False, - ), - nn.Sigmoid(), - ) - - self.x_res_modules = nn.ModuleList( - [ - nn.Sequential( - cunn.ResidualConv( - in_channels=in_features, - out_channels=n_hidden, - dilation=2, - activation_type='SiLU', - ), - nn.Dropout(0.5), - ), - nn.Sequential( - cunn.ResidualConv( - in_channels=in_features, - out_channels=n_hidden, - dilation=3, - activation_type='SiLU', - ), - nn.Dropout(0.5), - ), - nn.Sequential( - cunn.ResidualConv( - in_channels=in_features, - out_channels=n_hidden, - dilation=4, - activation_type='SiLU', - ), - nn.Dropout(0.5), - ), - ] - ) - self.crop_res_modules = nn.ModuleList( - [ - nn.Sequential( - cunn.ResidualConv( - in_channels=in_channels, - out_channels=n_hidden, - dilation=2, - activation_type='SiLU', - ), - nn.Dropout(0.5), - ), - nn.Sequential( - cunn.ResidualConv( - in_channels=in_channels, - out_channels=n_hidden, - dilation=3, - activation_type='SiLU', - ), - nn.Dropout(0.5), - ), - nn.Sequential( - cunn.ResidualConv( - in_channels=in_channels, - out_channels=n_hidden, - dilation=4, - activation_type='SiLU', - ), - nn.Dropout(0.5), - ), - ] - ) - - self.fc = nn.Sequential( - cunn.ConvBlock2d( - in_channels=( - (n_hidden * len(self.x_res_modules)) - + (n_hidden * len(self.crop_res_modules)) - ), - out_channels=n_hidden, - kernel_size=1, - padding=0, - activation_type="SiLU", - ), - nn.Conv2d( - in_channels=n_hidden, - out_channels=out_channels, - kernel_size=1, - padding=0, - ), - ) - self.softmax = nn.Softmax(dim=1) - - def proba_to_logit(self, x: torch.Tensor) -> torch.Tensor: - return torch.log(x / (1.0 - x)) - - def forward( - self, predictions: T.Dict[str, torch.Tensor], data: Data - ) -> T.Dict[str, torch.Tensor]: - """A single forward pass. 
- - Edge and crop inputs should be probabilities. - """ - height = ( - int(data.height) if data.batch is None else int(data.height[0]) - ) - width = int(data.width) if data.batch is None else int(data.width[0]) - batch_size = 1 if data.batch is None else data.batch.unique().size(0) - - latitude_norm = scale_min_max( - data.top - ((data.top - data.bottom) * 0.5), -90.0, 90.0, 0.0, 1.0 - ) - longitude_norm = scale_min_max( - data.left + ((data.right - data.left) * 0.5), - -180.0, - 180.0, - 0.0, - 1.0, - ) - lat_lon = torch.cat( - [ - latitude_norm.reshape(*latitude_norm.shape, 1, 1, 1), - longitude_norm.reshape(*longitude_norm.shape, 1, 1, 1), - ], - dim=1, - ) - geo_attention = self.geo_attention(lat_lon) - geo_attention = 1.0 + self.gamma * geo_attention - - crop_x = torch.cat( - [ - predictions["crop_star_l2"], - predictions["crop_star"], - predictions["dist"], - predictions["dist_3_1"], - predictions["dist_2_2"], - predictions["dist_1_3"], - predictions["edge"], - predictions["edge_3_1"], - predictions["edge_2_2"], - predictions["edge_1_3"], - predictions["crop"], - predictions["crop_3_1"], - predictions["crop_2_2"], - predictions["crop_1_3"], - ], - dim=1, - ) - x = torch.cat([m(crop_x) for m in self.x_res_modules], dim=1) - crop_x = torch.cat([m(crop_x) for m in self.crop_res_modules], dim=1) - - x = torch.cat([x, crop_x], dim=1) - x = self.softmax(self.fc(x) * geo_attention) - predictions["crop"] = x - - return predictions - - -class CropTypeFinal(nn.Module): - def __init__(self, in_channels: int, out_channels: int, out_classes: int): - super(CropTypeFinal, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.out_classes = out_classes - - self.conv1 = cunn.ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - activation_type="ReLU", - ) - layers1 = [ - cunn.ConvBlock2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type="ReLU", - ), - nn.Conv2d( - out_channels, - out_channels, - kernel_size=3, - padding=1, - bias=False, - ), - nn.BatchNorm2d(out_channels), - ] - self.seq = nn.Sequential(*layers1) - - layers_final = [ - nn.ReLU(inplace=False), - nn.Conv2d(out_channels, out_classes, kernel_size=1, padding=0), - ] - self.final = nn.Sequential(*layers_final) - - def forward( - self, x: torch.Tensor, crop_type_star: torch.Tensor - ) -> torch.Tensor: - out1 = self.conv1(x) - out = self.seq(out1) - out = out + out1 - out = self.final(out) - out = out + crop_type_star - - return out - - class CultioNet(nn.Module): """The cultionet model framework. @@ -287,7 +49,7 @@ def __init__( repeat_resa_kernel: bool = False, batchnorm_first: bool = False, ): - super(CultioNet, self).__init__() + super().__init__() self.in_channels = in_channels self.in_time = in_time @@ -328,17 +90,10 @@ def __init__( } assert model_type in ( - ModelTypes.UNET3PSI, - ModelTypes.RESUNET3PSI, - ModelTypes.TOWERUNET, + ModelTypes.TOWERUNET, ), "The model type is not supported."
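# Illustrative sketch (an assumption, not part of the source patch): with the
# other backbones removed, ModelTypes.TOWERUNET is the only value the assertion
# above accepts. The trailing comma in `(ModelTypes.TOWERUNET,)` matters: it
# makes the right-hand side a one-element tuple, whereas a bare parenthesized
# StrEnum member is just a string, and `in` would silently perform substring
# matching. A minimal construction, assuming the constructor keywords visible
# elsewhere in this diff (in_channels, in_time, model_type):
from cultionet.enums import ModelTypes
from cultionet.models.cultionet import CultioNet

model = CultioNet(
    in_channels=5,  # hypothetical number of input bands
    in_time=12,  # hypothetical number of time steps
    model_type=ModelTypes.TOWERUNET,  # any other value raises AssertionError
)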
- if model_type == ModelTypes.UNET3PSI: - self.mask_model = UNet3Psi(**unet3_kwargs) - elif model_type == ModelTypes.RESUNET3PSI: - self.mask_model = ResUNet3Psi(**unet3_kwargs) - else: - self.mask_model = TowerUNet(**unet3_kwargs) + self.mask_model = TowerUNet(**unet3_kwargs) def forward( self, batch: Data, training: bool = True diff --git a/src/cultionet/models/field_of_junctions.py b/src/cultionet/models/field_of_junctions.py deleted file mode 100644 index 42058167..00000000 --- a/src/cultionet/models/field_of_junctions.py +++ /dev/null @@ -1,481 +0,0 @@ -import typing as T - -import einops -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class FieldOfJunctions(nn.Module): - """ - Source: - https://github.com/dorverbin/fieldofjunctions - """ - - def __init__( - self, - in_channels: int, - height: int, - width: int, - patch_size: int = 8, - stride: int = 1, - nvals: int = 31, - delta: float = 0.05, - eta: float = 0.01, - ): - super(FieldOfJunctions, self).__init__() - - self.height = height - self.width = width - self.patch_size = patch_size - self.stride = stride - self.nvals = nvals - self.delta = delta - self.eta = eta - - self.reduce = nn.Sequential( - nn.Conv2d(in_channels, 3, kernel_size=1, padding=0, bias=False), - nn.BatchNorm2d(3), - nn.SiLU(), - ) - - # Number of patches (throughout the documentation H_patches and W_patches are denoted by H' and W' resp.) - self.h_patches = (height - patch_size) // stride + 1 - self.w_patches = (width - patch_size) // stride + 1 - - self.unfold = nn.Unfold(self.patch_size, stride=self.stride) - self.fold = nn.Fold( - output_size=[height, width], - kernel_size=self.patch_size, - stride=self.stride, - ) - - # Create local grid within each patch - meshy, meshx = torch.meshgrid( - [ - torch.linspace(-1.0, 1.0, self.patch_size), - torch.linspace(-1.0, 1.0, self.patch_size), - ], - indexing='ij', - ) - self.y = einops.rearrange(meshy, 'p k -> 1 p k 1 1') - self.x = einops.rearrange(meshx, 'p k -> 1 p k 1 1') - - # Values to search over in Algorithm 2: [0, 2pi) for angles, [-3, 3] for vertex position. 
- # self.angle_range = torch.linspace(0.0, 2 * np.pi, self.nvals + 1)[ - # : self.nvals - # ] - # self.x0_y0_range = torch.linspace(-3.0, 3.0, self.nvals) - - # Create pytorch variables for angles and vertex position for each patch - self.params = nn.Parameter( - torch.ones( - 1, 5, self.h_patches, self.w_patches, dtype=torch.float32 - ) - ) - - def forward(self, x: torch.Tensor) -> T.Dict[str, torch.Tensor]: - batch_size, in_channels, in_height, in_width = x.shape - - row_pad = 0 - col_pad = 0 - if (in_height, in_width) != (self.height, self.width): - row_pad = (self.height - in_height) // 2 - col_pad = (self.width - in_width) // 2 - x = F.pad( - x, - (row_pad, row_pad, col_pad, col_pad), - mode='constant', - value=0, - ) - - x = self.reduce(x) - - batch_size, num_channels, height, width = x.shape - - # Split image into overlapping patches, - # creating a tensor of shape [N, C, R, R, H', W'] - image_patches = einops.rearrange( - self.unfold(x), - 'b (c p k) (h w) -> b c p k h w', - p=self.patch_size, - k=self.patch_size, - h=self.h_patches, - w=self.w_patches, - ) - - # Compute number of patches containing each pixel: has shape [H, W] - num_patches = self.fold( - torch.ones( - batch_size, - self.patch_size**2, - self.h_patches * self.w_patches, - dtype=x.dtype, - device=x.device, - ), - ) - # Paper shape is (height x width) - num_patches = einops.rearrange(num_patches, 'b 1 h w -> b h w') - - self.y = self.y.to(device=x.device) - self.x = self.x.to(device=x.device) - # angle_range = self.angle_range.to(device=x.device) - # x0_y0_range = self.x0_y0_range.to(device=x.device) - - # params = self.params.detach() - - # # Run one step of Algorithm 2, sequentially improving each coordinate - # for i in range(5): - # # Repeat the set of parameters `nvals` times along 0th dimension - # params_query = params.repeat(self.nvals, 1, 1, 1) - # param_range = angle_range if i < 3 else x0_y0_range - # params_query[:, i, :, :] = params_query[ - # :, i, :, : - # ] + einops.rearrange(param_range, 'l -> l 1 1') - - # best_indices = self.get_best_indices( - # params_query, - # image_patches=image_patches, - # num_channels=num_channels, - # ) - - # # Update parameters - # params[0, i, :, :] = params_query[ - # einops.rearrange(best_indices, 'h w -> 1 h w'), - # i, - # einops.rearrange(torch.arange(self.h_patches), 'l -> 1 l 1'), - # einops.rearrange(torch.arange(self.w_patches), 'l -> 1 1 l'), - # ] - - # # Heuristic for accelerating convergence (not necessary but sometimes helps): - # # Update x0 and y0 along the three optimal angles (search over a line passing through current x0, y0) - # for i in range(3): - # params_query = params.repeat(self.nvals, 1, 1, 1) - # params_query[:, 3, :, :] = params[:, 3, :, :] + torch.cos( - # params[:, i, :, :] - # ) * x0_y0_range.view(-1, 1, 1) - # params_query[:, 4, :, :] = params[:, 4, :, :] + torch.sin( - # params[:, i, :, :] - # ) * x0_y0_range.view(-1, 1, 1) - # best_indices = self.get_best_indices( - # params_query, - # image_patches=image_patches, - # num_channels=num_channels, - # ) - - # # Update vertex positions of parameters - # for j in range(3, 5): - # params[:, j, :, :] = params_query[ - # einops.rearrange(best_indices, 'h w -> 1 h w'), - # j, - # einops.rearrange( - # torch.arange(self.h_patches), 'l -> 1 l 1' - # ), - # einops.rearrange( - # torch.arange(self.w_patches), 'l -> 1 1 l' - # ), - # ] - - # self.params.data = params.data - - # Compute distance functions, colors, and junction patches - distances, colors, patches = self.get_distances_and_patches( - 
self.params, - image_patches=image_patches, - num_channels=num_channels, - ) - # smoothed_image = self.local_to_global( - # patches, height, width, num_patches - # ) - local_boundaries = self.distances_to_boundaries(distances) - global_boundaries = self.local_to_global( - einops.rearrange( - local_boundaries, - '1 1 p k h w -> 1 1 1 p k h w', - ), - height, - width, - num_patches, - ) - # global_boundaries = self.final_boundaries(global_boundaries) - # # smoothed_image = self.final_image(smoothed_image) - - if row_pad > 0: - global_boundaries = global_boundaries[ - :, - :, - row_pad : row_pad + in_height, - col_pad : col_pad + in_width, - ] - - return { - "image_patches": image_patches, - "patches": patches, - "boundaries": global_boundaries, - } - - def distances_to_boundaries(self, dists: torch.Tensor) -> torch.Tensor: - """Compute boundary map for each patch, given distance functions. - - The width of the boundary is determined by opts.delta. - """ - # Find places where either distance transform is small, except where d1 > 0 and d2 < 0 - d1 = dists[:, 0:1, ...] - d2 = dists[:, 1:2, ...] - min_abs_distance = torch.where( - d1 < 0.0, - -d1, - torch.where(d2 < 0.0, torch.min(d1, -d2), torch.min(d1, d2)), - ) - - return 1.0 / (1.0 + (min_abs_distance / self.delta) ** 2) - - def local_to_global( - self, - patches: torch.Tensor, - height: int, - width: int, - num_patches: torch.Tensor, - ) -> torch.Tensor: - """Compute average value for each pixel over all patches containing it. - - For example, this can be used to compute the global boundary maps, or - the boundary-aware smoothed image. - """ - numerator = self.fold( - einops.rearrange(patches, 'b 1 c p k h w -> b (c p k) (h w)') - ) - denominator = einops.rearrange(num_patches, 'b h w -> b 1 h w') - - return numerator / denominator - - def get_best_indices( - self, - params: torch.Tensor, - image_patches: torch.Tensor, - num_channels: int, - ) -> torch.Tensor: - distances, colors, smooth_patches = self.get_distances_and_patches( - params, - image_patches=image_patches, - num_channels=num_channels, - ) - loss_per_patch = self.get_loss( - distances=distances, - colors=colors, - patches=smooth_patches, - image_patches=image_patches, - ) - best_indices = loss_per_patch.argmin(dim=0) - - return best_indices - - def get_distances_and_patches( - self, - params: torch.Tensor, - image_patches: torch.Tensor, - num_channels: int, - lmbda_color: float = 0.0, - ): - """Compute distance functions and piecewise-constant patches given - junction parameters.""" - # Get dists - distances = self.params_to_distances( - params - ) # shape [N, 2, R, R, H', W'] - - # Get wedge indicator functions - wedges = self.distances_to_indicators( - distances - ) # shape [N, 3, R, R, H', W'] - - # if lmbda_color >= 0 and self.global_image is not None: - # curr_global_image_patches = nn.Unfold(self.patch_size, stride=self.opts.stride)( - # self.global_image.detach()).view(1, num_channels, self.patch_size, self.patch_size, self.h_patches, self.w_patches) - - # numerator = ((self.img_patches + lmbda_color * - # curr_global_image_patches).unsqueeze(2) * wedges.unsqueeze(1)).sum(-3).sum(-3) - # denominator = (1.0 + lmbda_color) * wedges.sum(-3).sum(-3).unsqueeze(1) - - # colors = numerator / (denominator + 1e-10) - # else: - - numerator = einops.rearrange( - image_patches, 'b c p k h w -> b 1 c 1 p k h w' - ) * einops.rearrange(wedges, 'n c p k h w -> 1 n 1 c p k h w') - numerator = einops.reduce( - numerator, 'b n c l p k h w -> b n c l h w', 'sum' - ) - denominator = ( - 
einops.reduce(wedges, 'n c p k h w -> 1 n 1 c h w', 'sum') + 1e-10 - ) - colors = numerator / denominator - - # Fill wedges with optimal colors - patches = einops.rearrange( - wedges, 'n c p k h w -> 1 n 1 c p k h w' - ) * einops.rearrange(colors, 'b n c l h w -> b n c l 1 1 h w') - patches = einops.reduce( - patches, 'b n c l p k h w -> b n c p k h w', 'sum' - ) - - return distances, colors, patches - - def params_to_distances( - self, params: torch.Tensor, tau=1e-1 - ) -> torch.Tensor: - """Compute distance functions from field of junctions.""" - x0 = ( - params[:, 3, :, :].unsqueeze(1).unsqueeze(1) - ) # shape [N, 1, 1, H', W'] - y0 = ( - params[:, 4, :, :].unsqueeze(1).unsqueeze(1) - ) # shape [N, 1, 1, H', W'] - - # Sort so angle1 <= angle2 <= angle3 (mod 2pi) - angles = torch.remainder(params[:, :3, :, :], 2 * np.pi) - angles = torch.sort(angles, dim=1)[0] - - angle1 = ( - angles[:, 0, :, :].unsqueeze(1).unsqueeze(1) - ) # shape [N, 1, 1, H', W'] - angle2 = ( - angles[:, 1, :, :].unsqueeze(1).unsqueeze(1) - ) # shape [N, 1, 1, H', W'] - angle3 = ( - angles[:, 2, :, :].unsqueeze(1).unsqueeze(1) - ) # shape [N, 1, 1, H', W'] - - # Define another angle halfway between angle3 and angle1, clockwise from angle3 - # This isn't critical but it seems a bit more stable for computing gradients - angle4 = 0.5 * (angle1 + angle3) + torch.where( - torch.remainder(0.5 * (angle1 - angle3), 2 * np.pi) >= np.pi, - torch.ones_like(angle1) * np.pi, - torch.zeros_like(angle1), - ) - - def _g(dtheta): - # Map from [0, 2pi] to [-1, 1] - return (dtheta / np.pi - 1.0) ** 35 - - # Compute the two distance functions - sgn42 = torch.where( - torch.remainder(angle2 - angle4, 2 * np.pi) < np.pi, - torch.ones_like(angle2), - -torch.ones_like(angle2), - ) - tau42 = _g(torch.remainder(angle2 - angle4, 2 * np.pi)) * tau - - dist42 = ( - sgn42 - * torch.min( - sgn42 - * ( - -torch.sin(angle4) * (self.x - x0) - + torch.cos(angle4) * (self.y - y0) - ), - -sgn42 - * ( - -torch.sin(angle2) * (self.x - x0) - + torch.cos(angle2) * (self.y - y0) - ), - ) - + tau42 - ) - - sgn13 = torch.where( - torch.remainder(angle3 - angle1, 2 * np.pi) < np.pi, - torch.ones_like(angle3), - -torch.ones_like(angle3), - ) - tau13 = _g(torch.remainder(angle3 - angle1, 2 * np.pi)) * tau - dist13 = ( - sgn13 - * torch.min( - sgn13 - * ( - -torch.sin(angle1) * (self.x - x0) - + torch.cos(angle1) * (self.y - y0) - ), - -sgn13 - * ( - -torch.sin(angle3) * (self.x - x0) - + torch.cos(angle3) * (self.y - y0) - ), - ) - + tau13 - ) - - return torch.stack([dist13, dist42], dim=1) - - def distances_to_indicators(self, dists: torch.Tensor) -> torch.Tensor: - """Computes the indicator functions u_1, u_2, u_3 from the distance - functions d_{13}, d_{12}""" - # Apply smooth Heaviside function to distance functions - hdists = 0.5 * (1.0 + (2.0 / np.pi) * torch.atan(dists / self.eta)) - - # Convert Heaviside functions into wedge indicator functions - return torch.stack( - [ - 1.0 - hdists[:, 0, :, :, :, :], - hdists[:, 0, :, :, :, :] * (1.0 - hdists[:, 1, :, :, :, :]), - hdists[:, 0, :, :, :, :] * hdists[:, 1, :, :, :, :], - ], - dim=1, - ) - - def get_loss( - self, - distances: torch.Tensor, - colors: torch.Tensor, - patches: torch.Tensor, - image_patches: torch.Tensor, - lmbda_boundary: float = 0.0, - lmbda_color: float = 0.0, - ): - """Compute the objective of our model (see Equation 8 of the paper).""" - - # Compute negative log-likelihood for each patch (shape [N, H', W']) - loss_per_patch = einops.reduce( - ( - einops.rearrange(image_patches, 'b c p k h 
w -> b 1 c p k h w') - - patches - ) - ** 2, - 'b n c p k h w -> b n c h w', - 'mean', - ) - loss_per_patch = einops.reduce( - loss_per_patch, 'b n c h w -> b n h w', 'sum' - ) - # Reduce to the batch mean - loss_per_patch = einops.reduce( - loss_per_patch, 'b n h w -> n h w', 'mean' - ) - - return loss_per_patch - - -if __name__ == '__main__': - batch_size = 2 - num_channels = 3 - height = 100 - width = 100 - - x = torch.rand( - (batch_size, num_channels, height, width), - dtype=torch.float32, - ) - - foj = FieldOfJunctions( - in_channels=num_channels, - height=110, - width=110, - patch_size=8, - stride=1, - nvals=31, - delta=0.05, - eta=0.01, - ) - out = foj(x) diff --git a/src/cultionet/models/maskcrnn.py b/src/cultionet/models/maskcrnn.py index 35981a32..f33cd3e3 100644 --- a/src/cultionet/models/maskcrnn.py +++ b/src/cultionet/models/maskcrnn.py @@ -225,7 +225,7 @@ def __init__( min_image_size: int = 800, max_image_size: int = 1333, ) -> None: - super(BFasterRCNN, self).__init__() + super().__init__() if sizes is None: sizes = (32, 64, 128, 256, 512) diff --git a/src/cultionet/models/nunet.py b/src/cultionet/models/nunet.py index 98a78d9b..280474e2 100644 --- a/src/cultionet/models/nunet.py +++ b/src/cultionet/models/nunet.py @@ -19,7 +19,7 @@ class DepthwiseSeparableConv(nn.Module): def __init__( self, in_channels: int, hidden_channels: int, out_channels: int ): - super(DepthwiseSeparableConv, self).__init__() + super().__init__() self.separable = nn.Sequential( nn.Conv2d( @@ -48,7 +48,7 @@ def __init__( num_time: int, activation_type: str = 'SiLU', ): - super(ReduceTimeToOne, self).__init__() + super().__init__() self.conv = nn.Sequential( nn.Conv3d( @@ -83,7 +83,7 @@ def __init__( activation_type: str, trend_kernel_size: int = 5, ): - super(PreUnet3Psi, self).__init__() + super().__init__() self.reduce_time_init = ReduceTimeToOne( in_channels=in_channels, @@ -147,530 +147,6 @@ def forward( return encoded -class PostUNet3Psi(nn.Module): - def __init__( - self, - up_channels: int, - num_classes: int, - mask_activation: T.Callable, - deep_sup_dist: T.Optional[bool] = False, - deep_sup_edge: T.Optional[bool] = False, - deep_sup_mask: T.Optional[bool] = False, - ): - super(PostUNet3Psi, self).__init__() - - self.deep_sup_dist = deep_sup_dist - self.deep_sup_edge = deep_sup_edge - self.deep_sup_mask = deep_sup_mask - - self.up = cunn.UpSample() - - self.final_dist = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - nn.Sigmoid(), - ) - self.final_edge = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - cunn.SigmoidCrisp(), - ) - self.final_mask = nn.Sequential( - nn.Conv2d(up_channels, num_classes, kernel_size=1, padding=0), - mask_activation, - ) - if self.deep_sup_dist: - self.final_dist_3_1 = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - nn.Sigmoid(), - ) - self.final_dist_2_2 = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - nn.Sigmoid(), - ) - self.final_dist_1_3 = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - nn.Sigmoid(), - ) - if self.deep_sup_edge: - self.final_edge_3_1 = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - cunn.SigmoidCrisp(), - ) - self.final_edge_2_2 = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - cunn.SigmoidCrisp(), - ) - self.final_edge_1_3 = nn.Sequential( - nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - cunn.SigmoidCrisp(), - ) - if self.deep_sup_mask: - self.final_mask_3_1 = 
nn.Sequential( - nn.Conv2d(up_channels, num_classes, kernel_size=1, padding=0), - mask_activation, - ) - self.final_mask_2_2 = nn.Sequential( - nn.Conv2d(up_channels, num_classes, kernel_size=1, padding=0), - mask_activation, - ) - self.final_mask_1_3 = nn.Sequential( - nn.Conv2d(up_channels, num_classes, kernel_size=1, padding=0), - mask_activation, - ) - - def forward( - self, - out_0_4: T.Dict[str, torch.Tensor], - out_3_1: T.Dict[str, torch.Tensor], - out_2_2: T.Dict[str, torch.Tensor], - out_1_3: T.Dict[str, torch.Tensor], - ) -> T.Dict[str, torch.Tensor]: - dist = self.final_dist(out_0_4["dist"]) - edge = self.final_edge(out_0_4["edge"]) - mask = self.final_mask(out_0_4["mask"]) - - out = { - "dist": dist, - "edge": edge, - "mask": mask, - "dist_3_1": None, - "dist_2_2": None, - "dist_1_3": None, - "edge_3_1": None, - "edge_2_2": None, - "edge_1_3": None, - "mask_3_1": None, - "mask_2_2": None, - "mask_1_3": None, - } - - if self.deep_sup_dist: - out["dist_3_1"] = self.final_dist_3_1( - self.up(out_3_1["dist"], size=dist.shape[-2:], mode="bilinear") - ) - out["dist_2_2"] = self.final_dist_2_2( - self.up(out_2_2["dist"], size=dist.shape[-2:], mode="bilinear") - ) - out["dist_1_3"] = self.final_dist_1_3( - self.up(out_1_3["dist"], size=dist.shape[-2:], mode="bilinear") - ) - if self.deep_sup_edge: - out["edge_3_1"] = self.final_edge_3_1( - self.up(out_3_1["edge"], size=edge.shape[-2:], mode="bilinear") - ) - out["edge_2_2"] = self.final_edge_2_2( - self.up(out_2_2["edge"], size=edge.shape[-2:], mode="bilinear") - ) - out["edge_1_3"] = self.final_edge_1_3( - self.up(out_1_3["edge"], size=edge.shape[-2:], mode="bilinear") - ) - if self.deep_sup_mask: - out["mask_3_1"] = self.final_mask_3_1( - self.up(out_3_1["mask"], size=mask.shape[-2:], mode="bilinear") - ) - out["mask_2_2"] = self.final_mask_2_2( - self.up(out_2_2["mask"], size=mask.shape[-2:], mode="bilinear") - ) - out["mask_1_3"] = self.final_mask_1_3( - self.up(out_1_3["mask"], size=mask.shape[-2:], mode="bilinear") - ) - - return out - - -class UNet3Psi(nn.Module): - """UNet+++ with Psi-Net. 
- - References: - https://arxiv.org/ftp/arxiv/papers/2004/2004.08790.pdf - https://arxiv.org/abs/1902.04099 - https://github.com/Bala93/Multi-task-deep-network - """ - - def __init__( - self, - in_channels: int, - in_time: int, - in_encoding_channels: int, - hidden_channels: int = 32, - num_classes: int = 2, - dilation: int = 2, - activation_type: str = "SiLU", - deep_sup_dist: T.Optional[bool] = False, - deep_sup_edge: T.Optional[bool] = False, - deep_sup_mask: T.Optional[bool] = False, - mask_activation: T.Union[nn.Softmax, nn.Sigmoid] = nn.Softmax(dim=1), - ): - super(UNet3Psi, self).__init__() - - channels = [ - hidden_channels, - hidden_channels * 2, - hidden_channels * 4, - hidden_channels * 8, - hidden_channels * 16, - ] - up_channels = int(channels[0] * 5) - - self.pre_unet = PreUnet3Psi( - in_channels=in_channels, - channels=channels, - activation_type=activation_type, - ) - - # Inputs = - # Reduced time dimensions - # Reduced channels (x2) for mean and max - # Input filters for transformer hidden logits - self.conv0_0 = cunn.SingleConv( - in_channels=( - in_time - + int(channels[0] * 4) - + in_encoding_channels - # Peak kernels and Trend kernels - + in_time - ), - out_channels=channels[0], - activation_type=activation_type, - ) - self.conv1_0 = cunn.PoolConv( - channels[0], - channels[1], - double_dilation=dilation, - activation_type=activation_type, - ) - self.conv2_0 = cunn.PoolConv( - channels[1], - channels[2], - double_dilation=dilation, - activation_type=activation_type, - ) - self.conv3_0 = cunn.PoolConv( - channels[2], - channels[3], - double_dilation=dilation, - activation_type=activation_type, - ) - self.conv4_0 = cunn.PoolConv( - channels[3], - channels[4], - double_dilation=dilation, - activation_type=activation_type, - ) - - # Connect 3 - self.convs_3_1 = cunn.UNet3_3_1( - channels=channels, - up_channels=up_channels, - dilations=[dilation], - activation_type=activation_type, - ) - self.convs_2_2 = cunn.UNet3_2_2( - channels=channels, - up_channels=up_channels, - dilations=[dilation], - activation_type=activation_type, - ) - self.convs_1_3 = cunn.UNet3_1_3( - channels=channels, - up_channels=up_channels, - dilations=[dilation], - activation_type=activation_type, - ) - self.convs_0_4 = cunn.UNet3_0_4( - channels=channels, - up_channels=up_channels, - dilations=[dilation], - activation_type=activation_type, - ) - - self.post_unet = PostUNet3Psi( - up_channels=up_channels, - num_classes=num_classes, - mask_activation=mask_activation, - deep_sup_dist=deep_sup_dist, - deep_sup_edge=deep_sup_edge, - deep_sup_mask=deep_sup_mask, - ) - - # Initialise weights - self.apply(init_conv_weights) - - def forward( - self, x: torch.Tensor, temporal_encoding: torch.Tensor - ) -> T.Dict[str, T.Union[None, torch.Tensor]]: - # Inputs shape is (B x C X T|D x H x W) - h = self.pre_unet(x, temporal_encoding) - # h shape is (B x C x H x W) - # Backbone - # 1/1 - x0_0 = self.conv0_0(h) - # 1/2 - x1_0 = self.conv1_0(x0_0) - # 1/4 - x2_0 = self.conv2_0(x1_0) - # 1/8 - x3_0 = self.conv3_0(x2_0) - # 1/16 - x4_0 = self.conv4_0(x3_0) - - # 1/8 connection - out_3_1 = self.convs_3_1( - x0_0=x0_0, x1_0=x1_0, x2_0=x2_0, x3_0=x3_0, x4_0=x4_0 - ) - # 1/4 connection - out_2_2 = self.convs_2_2( - x0_0=x0_0, - x1_0=x1_0, - x2_0=x2_0, - h3_1_dist=out_3_1["dist"], - h3_1_edge=out_3_1["edge"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, - ) - # 1/2 connection - out_1_3 = self.convs_1_3( - x0_0=x0_0, - x1_0=x1_0, - h2_2_dist=out_2_2["dist"], - h3_1_dist=out_3_1["dist"], - h2_2_edge=out_2_2["edge"], - 
h3_1_edge=out_3_1["edge"], - h2_2_mask=out_2_2["mask"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, - ) - # 1/1 connection - out_0_4 = self.convs_0_4( - x0_0=x0_0, - h1_3_dist=out_1_3["dist"], - h2_2_dist=out_2_2["dist"], - h3_1_dist=out_3_1["dist"], - h1_3_edge=out_1_3["edge"], - h2_2_edge=out_2_2["edge"], - h3_1_edge=out_3_1["edge"], - h1_3_mask=out_1_3["mask"], - h2_2_mask=out_2_2["mask"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, - ) - - out = self.post_unet( - out_0_4=out_0_4, out_3_1=out_3_1, out_2_2=out_2_2, out_1_3=out_1_3 - ) - - return out - - -class ResUNet3Psi(nn.Module): - """Residual UNet+++ with Psi-Net (Multi-head streams) and Attention. - - References: - https://arxiv.org/ftp/arxiv/papers/2004/2004.08790.pdf - https://arxiv.org/abs/1902.04099 - https://github.com/Bala93/Multi-task-deep-network - https://github.com/hamidriasat/UNet-3-Plus - """ - - def __init__( - self, - in_channels: int, - in_time: int, - hidden_channels: int = 32, - num_classes: int = 2, - dilations: T.Sequence[int] = None, - activation_type: str = "SiLU", - res_block_type: str = ResBlockTypes.RES, - attention_weights: T.Optional[str] = None, - deep_sup_dist: T.Optional[bool] = False, - deep_sup_edge: T.Optional[bool] = False, - deep_sup_mask: T.Optional[bool] = False, - mask_activation: T.Union[nn.Softmax, nn.Sigmoid] = nn.Softmax(dim=1), - ): - super(ResUNet3Psi, self).__init__() - - if dilations is None: - dilations = [2] - if attention_weights is None: - attention_weights = "spatial_channel" - - channels = [ - hidden_channels, - hidden_channels * 2, - hidden_channels * 4, - hidden_channels * 8, - hidden_channels * 16, - ] - up_channels = int(channels[0] * 5) - - self.pre_unet = PreUnet3Psi( - in_channels=in_channels, - in_time=in_time, - channels=channels, - activation_type=activation_type, - ) - - # Inputs = - # Reduced time dimensions - # Reduced channels (x2) for mean and max - # Input filters for RNN hidden logits - if res_block_type.lower() == ResBlockTypes.RES: - self.conv0_0 = cunn.ResidualConv( - in_channels=channels[0], - out_channels=channels[0], - dilation=dilations[0], - activation_type=activation_type, - attention_weights=attention_weights, - ) - else: - self.conv0_0 = cunn.ResidualAConv( - in_channels=channels[0], - out_channels=channels[0], - dilations=dilations, - activation_type=activation_type, - attention_weights=attention_weights, - ) - self.conv1_0 = cunn.PoolResidualConv( - channels[0], - channels[1], - dilations=dilations, - attention_weights=attention_weights, - res_block_type=res_block_type, - ) - self.conv2_0 = cunn.PoolResidualConv( - channels[1], - channels[2], - dilations=dilations, - activation_type=activation_type, - attention_weights=attention_weights, - res_block_type=res_block_type, - ) - self.conv3_0 = cunn.PoolResidualConv( - channels[2], - channels[3], - dilations=dilations, - activation_type=activation_type, - attention_weights=attention_weights, - res_block_type=res_block_type, - ) - self.conv4_0 = cunn.PoolResidualConv( - channels[3], - channels[4], - dilations=dilations, - activation_type=activation_type, - attention_weights=attention_weights, - res_block_type=res_block_type, - ) - - # Connect 3 - self.convs_3_1 = cunn.ResUNet3_3_1( - channels=channels, - up_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - res_block_type=res_block_type, - ) - self.convs_2_2 = cunn.ResUNet3_2_2( - channels=channels, - up_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - 
activation_type=activation_type, - res_block_type=res_block_type, - ) - self.convs_1_3 = cunn.ResUNet3_1_3( - channels=channels, - up_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - res_block_type=res_block_type, - ) - self.convs_0_4 = cunn.ResUNet3_0_4( - channels=channels, - up_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - res_block_type=res_block_type, - ) - - self.post_unet = PostUNet3Psi( - up_channels=up_channels, - num_classes=num_classes, - mask_activation=mask_activation, - deep_sup_dist=deep_sup_dist, - deep_sup_edge=deep_sup_edge, - deep_sup_mask=deep_sup_mask, - ) - - # Initialise weights - self.apply(init_conv_weights) - - def forward( - self, - x: torch.Tensor, - temporal_encoding: T.Optional[torch.Tensor] = None, - ) -> T.Dict[str, T.Union[None, torch.Tensor]]: - # Inputs shape is (B x C X T|D x H x W) - h = self.pre_unet(x, temporal_encoding=temporal_encoding) - # h shape is (B x C x H x W) - - # Backbone - # 1/1 - x0_0 = self.conv0_0(h) - # 1/2 - x1_0 = self.conv1_0(x0_0) - # 1/4 - x2_0 = self.conv2_0(x1_0) - # 1/8 - x3_0 = self.conv3_0(x2_0) - # 1/16 - x4_0 = self.conv4_0(x3_0) - - # 1/8 connection - out_3_1 = self.convs_3_1( - side=x3_0, - down=x4_0, - pools=[x0_0, x1_0, x2_0], - ) - # 1/4 connection - out_2_2 = self.convs_2_2( - side=x2_0, - dist_down=[out_3_1["dist"]], - edge_down=[out_3_1["edge"]], - mask_down=[out_3_1["mask"]], - down=x4_0, - pools=[x0_0, x1_0], - ) - # 1/2 connection - out_1_3 = self.convs_1_3( - side=x1_0, - dist_down=[out_3_1["dist"], out_2_2["dist"]], - edge_down=[out_3_1["edge"], out_2_2["edge"]], - mask_down=[out_3_1["mask"], out_2_2["mask"]], - down=x4_0, - pools=[x0_0], - ) - # 1/1 connection - out_0_4 = self.convs_0_4( - side=x0_0, - dist_down=[out_3_1["dist"], out_2_2["dist"], out_1_3['dist']], - edge_down=[out_3_1["edge"], out_2_2["edge"], out_1_3['edge']], - mask_down=[out_3_1["mask"], out_2_2["mask"], out_1_3['mask']], - down=x4_0, - ) - - out = self.post_unet( - out_0_4=out_0_4, - out_3_1=out_3_1, - out_2_2=out_2_2, - out_1_3=out_1_3, - ) - - return out - - class TowerUNet(nn.Module): """Tower U-Net.""" @@ -693,7 +169,7 @@ def __init__( batchnorm_first: bool = False, concat_resid: bool = False, ): - super(TowerUNet, self).__init__() + super().__init__() if dilations is None: dilations = [1, 2] @@ -831,329 +307,3 @@ def forward( out.update(out_c) return out - - -class _TowerUNet(nn.Module): - """Tower U-Net.""" - - def __init__( - self, - in_channels: int, - in_time: int, - hidden_channels: int = 64, - num_classes: int = 2, - dilations: T.Sequence[int] = None, - activation_type: str = "SiLU", - dropout: float = 0.0, - res_block_type: str = ResBlockTypes.RESA, - attention_weights: str = AttentionTypes.SPATIAL_CHANNEL, - mask_activation: T.Union[nn.Softmax, nn.Sigmoid] = nn.Softmax(dim=1), - deep_supervision: bool = False, - pool_attention: bool = False, - pool_by_max: bool = False, - repeat_resa_kernel: bool = False, - batchnorm_first: bool = False, - concat_resid: bool = False, - ): - super(TowerUNet, self).__init__() - - if dilations is None: - dilations = [1, 2] - - self.deep_supervision = deep_supervision - - channels = [ - hidden_channels, # a - hidden_channels * 2, # b - hidden_channels * 4, # c - hidden_channels * 8, # d - ] - up_channels = int(hidden_channels * len(channels)) - - self.pre_unet = PreUnet3Psi( - in_channels=in_channels, - in_time=in_time, - out_channels=channels[0], - 
activation_type=activation_type, - ) - - # Backbone layers - backbone_kwargs = dict( - dropout=dropout, - activation_type=activation_type, - res_block_type=res_block_type, - batchnorm_first=batchnorm_first, - pool_by_max=pool_by_max, - concat_resid=concat_resid, - natten_num_heads=8, - natten_kernel_size=3, - natten_dilation=1, - natten_attn_drop=dropout, - natten_proj_drop=dropout, - ) - self.down_a = cunn.PoolResidualConv( - in_channels=channels[0], - out_channels=channels[0], - dilations=dilations, - repeat_resa_kernel=repeat_resa_kernel, - pool_first=False, - attention_weights=attention_weights if pool_attention else None, - **backbone_kwargs, - ) - self.down_b = cunn.PoolResidualConv( - in_channels=channels[0], - out_channels=channels[1], - dilations=dilations, - repeat_resa_kernel=repeat_resa_kernel, - attention_weights=attention_weights if pool_attention else None, - **backbone_kwargs, - ) - self.down_c = cunn.PoolResidualConv( - channels[1], - channels[2], - dilations=dilations[:2], - repeat_resa_kernel=repeat_resa_kernel, - attention_weights=attention_weights if pool_attention else None, - **backbone_kwargs, - ) - self.down_d = cunn.PoolResidualConv( - channels[2], - channels[3], - kernel_size=1, - num_blocks=1, - dilations=[1], - repeat_resa_kernel=repeat_resa_kernel, - attention_weights=None, - **backbone_kwargs, - ) - - # Up layers - up_kwargs = dict( - activation_type=activation_type, - res_block_type=res_block_type, - repeat_resa_kernel=repeat_resa_kernel, - batchnorm_first=batchnorm_first, - concat_resid=concat_resid, - natten_num_heads=8, - natten_attn_drop=dropout, - natten_proj_drop=dropout, - ) - self.over_d = cunn.UNetUpBlock( - in_channels=channels[3], - out_channels=up_channels, - kernel_size=1, - num_blocks=1, - dilations=[1], - attention_weights=None, - resample_up=False, - **up_kwargs, - ) - self.up_cu = cunn.UNetUpBlock( - in_channels=up_channels, - out_channels=up_channels, - attention_weights=attention_weights, - dilations=dilations[:2], - natten_kernel_size=3, - natten_dilation=1, - **up_kwargs, - ) - self.up_bu = cunn.UNetUpBlock( - in_channels=up_channels, - out_channels=up_channels, - attention_weights=attention_weights, - dilations=dilations, - natten_kernel_size=5, - natten_dilation=2, - **up_kwargs, - ) - self.up_au = cunn.UNetUpBlock( - in_channels=up_channels, - out_channels=up_channels, - attention_weights=attention_weights, - dilations=dilations, - natten_kernel_size=7, - natten_dilation=3, - **up_kwargs, - ) - - # Towers - tower_kwargs = dict( - up_channels=up_channels, - out_channels=up_channels, - attention_weights=attention_weights, - activation_type=activation_type, - res_block_type=res_block_type, - repeat_resa_kernel=repeat_resa_kernel, - batchnorm_first=batchnorm_first, - concat_resid=concat_resid, - natten_num_heads=8, - natten_attn_drop=dropout, - natten_proj_drop=dropout, - ) - self.tower_c = cunn.TowerUNetBlock( - backbone_side_channels=channels[2], - backbone_down_channels=channels[3], - dilations=dilations[:2], - natten_kernel_size=3, - natten_dilation=1, - **tower_kwargs, - ) - self.tower_b = cunn.TowerUNetBlock( - backbone_side_channels=channels[1], - backbone_down_channels=channels[2], - tower=True, - dilations=dilations, - natten_kernel_size=5, - natten_dilation=2, - **tower_kwargs, - ) - self.tower_a = cunn.TowerUNetBlock( - backbone_side_channels=channels[0], - backbone_down_channels=channels[1], - tower=True, - dilations=dilations, - natten_kernel_size=7, - natten_dilation=3, - **tower_kwargs, - ) - - self.final_a = 
cunn.TowerUNetFinal( - in_channels=up_channels, - num_classes=num_classes, - mask_activation=mask_activation, - activation_type=activation_type, - ) - - if self.deep_supervision: - self.final_b = cunn.TowerUNetFinal( - in_channels=up_channels, - num_classes=num_classes, - mask_activation=mask_activation, - activation_type=activation_type, - resample_factor=2, - ) - self.final_c = cunn.TowerUNetFinal( - in_channels=up_channels, - num_classes=num_classes, - mask_activation=mask_activation, - activation_type=activation_type, - resample_factor=4, - ) - - # Initialize weights - self.apply(init_conv_weights) - - def forward( - self, - x: torch.Tensor, - temporal_encoding: T.Optional[torch.Tensor] = None, - training: bool = True, - ) -> T.Dict[str, torch.Tensor]: - - """Forward pass. - - Parameters - ========== - x - Shaped (B x C x T x H x W) - temporal_encoding - Shaped (B x C x H X W) - """ - - # Initial temporal reduction and convolutions to - # hidden dimensions - embeddings = self.pre_unet(x, temporal_encoding=temporal_encoding) - - # Backbone - x_a = self.down_a(embeddings) # 1/1 of input - x_b = self.down_b(x_a) # 1/2 of input - x_c = self.down_c(x_b) # 1/4 of input - x_d = self.down_d(x_c) # 1/8 of input - - x_du = self.over_d(x_d, size=x_d.shape[-2:]) - - # Up - x_cu = self.up_cu(x_du, size=x_c.shape[-2:]) - x_bu = self.up_bu(x_cu, size=x_b.shape[-2:]) - x_au = self.up_au(x_bu, size=x_a.shape[-2:]) - - # Central towers - x_tower_c = self.tower_c( - backbone_side=x_c, - backbone_down=x_d, - decode_side=x_cu, - decode_down=x_du, - ) - x_tower_b = self.tower_b( - backbone_side=x_b, - backbone_down=x_c, - decode_side=x_bu, - decode_down=x_cu, - tower_down=x_tower_c, - ) - x_tower_a = self.tower_a( - backbone_side=x_a, - backbone_down=x_b, - decode_side=x_au, - decode_down=x_bu, - tower_down=x_tower_b, - ) - - # Final outputs - out = self.final_a(x_tower_a) - - if training and self.deep_supervision: - out_c = self.final_c( - x_tower_c, - size=x_tower_a.shape[-2:], - suffix="_c", - ) - out_b = self.final_b( - x_tower_b, - size=x_tower_a.shape[-2:], - suffix="_b", - ) - - out.update(out_b) - out.update(out_c) - - return out - - -if __name__ == '__main__': - batch_size = 2 - num_channels = 3 - hidden_channels = 32 - num_head = 8 - num_time = 13 - height = 100 - width = 100 - - x = torch.rand( - (batch_size, num_channels, num_time, height, width), - dtype=torch.float32, - ) - logits_hidden = torch.rand( - (batch_size, hidden_channels, height, width), dtype=torch.float32 - ) - - model = TowerUNet( - in_channels=num_channels, - in_time=num_time, - hidden_channels=hidden_channels, - dilations=[1, 2], - dropout=0.2, - res_block_type=ResBlockTypes.RESA, - attention_weights=AttentionTypes.SPATIAL_CHANNEL, - deep_supervision=False, - pool_attention=False, - pool_first=False, - repeat_resa_kernel=False, - batchnorm_first=True, - ) - - logits = model(x, temporal_encoding=logits_hidden) - - assert logits['dist'].shape == (batch_size, 1, height, width) - assert logits['edge'].shape == (batch_size, 1, height, width) - assert logits['mask'].shape == (batch_size, 2, height, width) diff --git a/src/cultionet/models/temporal_transformer.py b/src/cultionet/models/temporal_transformer.py index c582db7f..c5cfc570 100644 --- a/src/cultionet/models/temporal_transformer.py +++ b/src/cultionet/models/temporal_transformer.py @@ -30,7 +30,7 @@ def __init__( scale: float, dropout: float = 0.1, ): - super(ScaledDotProductAttention, self).__init__() + super().__init__() self.dropout = None if dropout > 0: @@ -64,7 +64,7 @@ 
class MultiHeadAttention(nn.Module): """ def __init__(self, d_model: int, num_head: int, dropout: float = 0.1): - super(MultiHeadAttention, self).__init__() + super().__init__() self.num_head = num_head d_k = d_model // num_head @@ -118,7 +118,7 @@ def forward( class PositionWiseFeedForward(nn.Module): def __init__(self, d_model: int, hidden_channels: int): - super(PositionWiseFeedForward, self).__init__() + super().__init__() self.fc1 = nn.Linear(d_model, hidden_channels) self.fc2 = nn.Linear(hidden_channels, d_model) @@ -135,7 +135,7 @@ def __init__( num_head: int, dropout: float = 0.1, ): - super(EncoderLayer, self).__init__() + super().__init__() self.self_attn = MultiHeadAttention( d_model=d_model, num_head=num_head, dropout=dropout @@ -162,7 +162,7 @@ def __init__( num_layers: int, dropout: float = 0.1, ): - super(Transformer, self).__init__() + super().__init__() self.encoder_layers = nn.ModuleList( [ @@ -190,7 +190,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class InLayer(nn.Module): def __init__(self, in_channels: int, out_channels: int): - super(InLayer, self).__init__() + super().__init__() self.seq = nn.Sequential( nn.Conv3d( @@ -217,7 +217,7 @@ def __init__( hidden_channels: int, out_channels: int, ): - super(InBlock, self).__init__() + super().__init__() self.seq = nn.Sequential( InLayer(in_channels=in_channels, out_channels=hidden_channels), @@ -235,7 +235,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Identity(nn.Module): def __init__(self): - super(Identity, self).__init__() + super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: return x @@ -251,7 +251,7 @@ def __init__( activation_type: str = "SiLU", final_activation: Callable = nn.Softmax(dim=1), ): - super(TemporalTransformerFinal, self).__init__() + super().__init__() # Level 2 level (non-crop; crop) self.final_l2 = cunn.FinalConv2dDropout( @@ -316,7 +316,7 @@ def __init__( to project them into a feature space of dimension d_model. time_scaler (int): Period to use for the positional encoding. 
""" - super(TemporalTransformer, self).__init__() + super().__init__() self.d_model = d_model self.num_classes_l2 = num_classes_l2 @@ -432,28 +432,3 @@ def forward(self, x: torch.Tensor) -> dict: "l3": encoded["l3"], "encoded": encoded["hidden"], } - - -if __name__ == '__main__': - batch_size = 2 - num_channels = 3 - hidden_channels = 64 - num_head = 8 - d_model = 128 - in_time = 13 - height = 100 - width = 100 - - x = torch.rand( - (batch_size, num_channels, in_time, height, width), - dtype=torch.float32, - ) - - model = TemporalTransformer( - in_channels=num_channels, - hidden_channels=hidden_channels, - num_head=num_head, - d_model=d_model, - in_time=in_time, - ) - output = model(x) diff --git a/src/cultionet/nn/__init__.py b/src/cultionet/nn/__init__.py index 495ad973..7251f844 100644 --- a/src/cultionet/nn/__init__.py +++ b/src/cultionet/nn/__init__.py @@ -12,20 +12,11 @@ from .modules.kernels import Peaks3d, Trend3d from .modules.reshape import UpSample from .modules.unet_parts import ( - ResELUNetPsiBlock, - ResUNet3_0_4, - ResUNet3_1_3, - ResUNet3_2_2, - ResUNet3_3_1, TowerUNetBlock, TowerUNetDecoder, TowerUNetEncoder, TowerUNetFinal, TowerUNetFusion, - UNet3_0_4, - UNet3_1_3, - UNet3_2_2, - UNet3_3_1, UNetUpBlock, ) @@ -49,13 +40,4 @@ 'TowerUNetEncoder', 'TowerUNetDecoder', 'TowerUNetFusion', - 'ResELUNetPsiBlock', - 'ResUNet3_0_4', - 'ResUNet3_1_3', - 'ResUNet3_2_2', - 'ResUNet3_3_1', - 'UNet3_0_4', - 'UNet3_1_3', - 'UNet3_2_2', - 'UNet3_3_1', ] diff --git a/src/cultionet/nn/modules/activations.py b/src/cultionet/nn/modules/activations.py index 2ffb413e..60b26b5e 100644 --- a/src/cultionet/nn/modules/activations.py +++ b/src/cultionet/nn/modules/activations.py @@ -7,7 +7,7 @@ class LogSoftmax(nn.Module): def __init__(self, dim: int = 1): - super(LogSoftmax, self).__init__() + super().__init__() self.dim = dim @@ -17,7 +17,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Softmax(nn.Module): def __init__(self, dim: int = 1): - super(Softmax, self).__init__() + super().__init__() self.dim = dim @@ -27,7 +27,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Swish(nn.Module): def __init__(self, channels: int, dims: int): - super(Swish, self).__init__() + super().__init__() self.sigmoid = nn.Sigmoid() self.beta = nn.Parameter(torch.ones(1)) @@ -57,7 +57,7 @@ def __init__( >>> act = SetActivation('Swish', channels=32) >>> act(x) """ - super(SetActivation, self).__init__() + super().__init__() if activation_type == "Swish": assert isinstance( @@ -118,7 +118,7 @@ class SigmoidCrisp(nn.Module): """ def __init__(self, smooth: float = 1e-2): - super(SigmoidCrisp, self).__init__() + super().__init__() self.smooth = smooth self.gamma = nn.Parameter(torch.ones(1)) diff --git a/src/cultionet/nn/modules/attention.py b/src/cultionet/nn/modules/attention.py index 3fd89ced..a5c2ac08 100644 --- a/src/cultionet/nn/modules/attention.py +++ b/src/cultionet/nn/modules/attention.py @@ -19,7 +19,7 @@ def __init__( add_activation: bool = True, activation_type: str = "SiLU", ): - super(ConvBlock2d, self).__init__() + super().__init__() layers = [ nn.Conv2d( @@ -45,7 +45,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AttentionAdd(nn.Module): def __init__(self): - super(AttentionAdd, self).__init__() + super().__init__() self.up = UpSample() @@ -58,7 +58,7 @@ def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: class AttentionGate(nn.Module): def __init__(self, high_channels: int, low_channels: int): - super(AttentionGate, self).__init__() + 
super().__init__() conv_x = nn.Conv2d( high_channels, high_channels, kernel_size=1, padding=0 @@ -143,7 +143,7 @@ def __init__( dim: T.Union[int, T.Sequence[int]] = 0, targets_are_labels: bool = True, ): - super(TanimotoComplement, self).__init__() + super().__init__() self.smooth = smooth self.depth = depth @@ -250,7 +250,7 @@ def __init__( weight: T.Optional[torch.Tensor] = None, dim: T.Union[int, T.Sequence[int]] = 0, ): - super(TanimotoDist, self).__init__() + super().__init__() self.smooth = smooth self.weight = weight @@ -318,7 +318,7 @@ class FractalAttention(nn.Module): """ def __init__(self, in_channels: int, out_channels: int): - super(FractalAttention, self).__init__() + super().__init__() self.query = nn.Sequential( ConvBlock2d( @@ -374,7 +374,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ChannelAttention(nn.Module): def __init__(self, out_channels: int, activation_type: str): - super(ChannelAttention, self).__init__() + super().__init__() # Channel attention self.channel_adaptive_avg = nn.AdaptiveAvgPool2d(1) @@ -409,7 +409,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SpatialAttention(nn.Module): def __init__(self): - super(SpatialAttention, self).__init__() + super().__init__() self.conv = nn.Conv2d( in_channels=2, @@ -442,7 +442,7 @@ class SpatialChannelAttention(nn.Module): """ def __init__(self, out_channels: int, activation_type: str): - super(SpatialChannelAttention, self).__init__() + super().__init__() self.channel_attention = ChannelAttention( out_channels=out_channels, activation_type=activation_type diff --git a/src/cultionet/nn/modules/convolution.py b/src/cultionet/nn/modules/convolution.py index ab4e55f1..10cfaabb 100644 --- a/src/cultionet/nn/modules/convolution.py +++ b/src/cultionet/nn/modules/convolution.py @@ -27,7 +27,7 @@ def __init__( stride: int = 2, padding: int = 1, ): - super(ConvTranspose2d, self).__init__() + super().__init__() self.up_conv = nn.ConvTranspose2d( in_channels=in_channels, @@ -57,7 +57,7 @@ def __init__( activation_type: str = "SiLU", batchnorm_first: bool = False, ): - super(ConvBlock2d, self).__init__() + super().__init__() layers = [] @@ -111,7 +111,7 @@ def __init__( double_dilation: int = 1, activation_type: str = "SiLU", ): - super(DoubleConv, self).__init__() + super().__init__() layers = [] @@ -161,7 +161,7 @@ def __init__( dilation_c: int = 3, dilation_d: int = 4, ): - super(AtrousPyramidPooling, self).__init__() + super().__init__() self.up = UpSample() @@ -239,7 +239,7 @@ def __init__( activation_type: str = "SiLU", dropout: T.Optional[float] = None, ): - super(PoolConv, self).__init__() + super().__init__() layers = [nn.MaxPool2d(pool_size)] if dropout is not None: @@ -273,7 +273,7 @@ def __init__( repeat_kernel: bool = False, batchnorm_first: bool = False, ): - super(ResConvBlock2d, self).__init__() + super().__init__() assert ( 0 < num_blocks < 3 @@ -348,7 +348,7 @@ def __init__( activation_type: str = "SiLU", batchnorm_first: bool = False, ): - super(ResidualConv, self).__init__() + super().__init__() self.attention_weights = attention_weights @@ -489,7 +489,7 @@ def __init__( natten_attn_drop: float = 0.0, natten_proj_drop: float = 0.0, ): - super(ResidualAConv, self).__init__() + super().__init__() self.attention_weights = attention_weights self.concat_resid = concat_resid @@ -633,7 +633,7 @@ def __init__( natten_attn_drop: float = 0.0, natten_proj_drop: float = 0.0, ): - super(PoolResidualConv, self).__init__() + super().__init__() assert res_block_type in ( ResBlockTypes.RES, @@ 
-727,7 +727,7 @@ def __init__( out_channels: int, activation_type: str = "SiLU", ): - super(SingleConv, self).__init__() + super().__init__() self.seq = ConvBlock2d( in_channels=in_channels, @@ -750,7 +750,7 @@ def __init__( final_activation: T.Callable, num_classes: int, ): - super(FinalConv2dDropout, self).__init__() + super().__init__() self.net = nn.Sequential( ResidualConv( diff --git a/src/cultionet/nn/modules/kernels.py b/src/cultionet/nn/modules/kernels.py index ee2e2f62..d2c0fe67 100644 --- a/src/cultionet/nn/modules/kernels.py +++ b/src/cultionet/nn/modules/kernels.py @@ -25,7 +25,7 @@ class Trend3d(torch.nn.Module): def __init__(self, kernel_size: int, direction: str = "positive"): - super(Trend3d, self).__init__() + super().__init__() assert direction in ( "positive", @@ -63,7 +63,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Peaks3d(torch.nn.Module): def __init__(self, kernel_size: int, radius: int = 9, sigma: float = 1.5): - super(Peaks3d, self).__init__() + super().__init__() self.padding = (kernel_size // 2, 0, 0) x = torch.linspace(-radius, radius + 1, kernel_size) diff --git a/src/cultionet/nn/modules/reshape.py b/src/cultionet/nn/modules/reshape.py index 490dbb54..8a2c1fd1 100644 --- a/src/cultionet/nn/modules/reshape.py +++ b/src/cultionet/nn/modules/reshape.py @@ -1,19 +1,14 @@ import typing as T -import einops import torch import torch.nn as nn -def get_batch_count(batch: torch.Tensor) -> int: - return batch.unique().size(0) - - class UpSample(nn.Module): """Up-samples a tensor.""" def __init__(self): - super(UpSample, self).__init__() + super().__init__() def forward( self, x: torch.Tensor, size: T.Sequence[int], mode: str = "bilinear" @@ -21,74 +16,3 @@ def forward( upsampler = nn.Upsample(size=size, mode=mode, align_corners=True) return upsampler(x) - - -class GraphToConv(nn.Module): - """Reshapes a 2d tensor to a 4d tensor.""" - - def __init__(self): - super(GraphToConv, self).__init__() - - def forward( - self, x: torch.Tensor, nbatch: int, nrows: int, ncols: int - ) -> torch.Tensor: - return einops.rearrange( - x, - '(b h w) c -> b c h w', - b=nbatch, - c=x.shape[1], - h=nrows, - w=ncols, - ) - - -class ConvToGraph(nn.Module): - """Reshapes a 4d tensor to a 2d tensor.""" - - def __init__(self): - super(ConvToGraph, self).__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return einops.rearrange(x, 'b c h w -> (b h w) c') - - -class ConvToTime(nn.Module): - """Reshapes a 4d tensor to a 5d tensor.""" - - def __init__(self): - super(ConvToTime, self).__init__() - - def forward( - self, x: torch.Tensor, nbands: int, ntime: int - ) -> torch.Tensor: - nbatch, __, height, width = x.shape - - return einops.rearrange( - x, - 'b (bands t) h w -> b bands t h w', - b=nbatch, - bands=nbands, - t=ntime, - h=height, - w=width, - ) - - -class Squeeze(nn.Module): - def __init__(self, dim: T.Optional[int] = None): - super(Squeeze, self).__init__() - - self.dim = dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.squeeze(dim=self.dim) - - -class Unsqueeze(nn.Module): - def __init__(self, dim: int): - super(Unsqueeze, self).__init__() - - self.dim = dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.unsqueeze(self.dim) diff --git a/src/cultionet/nn/modules/unet_parts.py b/src/cultionet/nn/modules/unet_parts.py index e866253e..45b92e91 100644 --- a/src/cultionet/nn/modules/unet_parts.py +++ b/src/cultionet/nn/modules/unet_parts.py @@ -1,4 +1,3 @@ -import enum import typing as T import torch @@ -6,25 +5,21 @@ 
from einops import rearrange from einops.layers.torch import Rearrange -from cultionet.enums import AttentionTypes, ModelTypes, ResBlockTypes +from cultionet.enums import AttentionTypes, ResBlockTypes from .activations import SigmoidCrisp -from .attention import AttentionGate from .convolution import ( ConvBlock2d, ConvTranspose2d, - DoubleConv, - PoolConv, PoolResidualConv, ResidualAConv, ResidualConv, ) -from .reshape import UpSample class GeoEmbeddings(nn.Module): def __init__(self, channels: int): - super(GeoEmbeddings, self).__init__() + super().__init__() self.coord_embedding = nn.Linear(3, channels) @@ -53,7 +48,7 @@ def __init__( activation_type: str = "SiLU", resample_factor: int = 0, ): - super(TowerUNetFinal, self).__init__() + super().__init__() self.in_channels = in_channels self.num_classes = num_classes @@ -68,6 +63,7 @@ def __init__( padding=1, ) + # TODO: make optional self.geo_embeddings = GeoEmbeddings(in_channels) self.layernorm = nn.Sequential( Rearrange('b c h w -> b h w c'), @@ -99,7 +95,7 @@ def __init__( def forward( self, x: torch.Tensor, - latlon_coords: T.Optional[torch.Tensor], + latlon_coords: T.Optional[torch.Tensor] = None, size: T.Optional[torch.Size] = None, suffix: str = "", ) -> T.Dict[str, torch.Tensor]: @@ -107,7 +103,11 @@ def forward( x = self.up_conv(x, size=size) # Embed coordinates - x = x + rearrange(self.geo_embeddings(latlon_coords), 'b c -> b c 1 1') + if latlon_coords is not None: + x = x + rearrange( + self.geo_embeddings(latlon_coords), 'b c -> b c 1 1' + ) + x = self.layernorm(x) # Expand into separate streams @@ -151,7 +151,7 @@ def __init__( natten_attn_drop: float = 0.0, natten_proj_drop: float = 0.0, ): - super(UNetUpBlock, self).__init__() + super().__init__() if resample_up: self.up_conv = ConvTranspose2d(in_channels, in_channels) @@ -206,7 +206,7 @@ def __init__( batchnorm_first: bool = False, concat_resid: bool = False, ): - super(TowerUNetEncoder, self).__init__() + super().__init__() # Backbone layers backbone_kwargs = dict( @@ -287,7 +287,7 @@ def __init__( batchnorm_first: bool = False, concat_resid: bool = False, ): - super(TowerUNetDecoder, self).__init__() + super().__init__() # Up layers up_kwargs = dict( @@ -370,7 +370,7 @@ def __init__( batchnorm_first: bool = False, concat_resid: bool = False, ): - super(TowerUNetFusion, self).__init__() + super().__init__() # Towers tower_kwargs = dict( @@ -470,7 +470,7 @@ def __init__( natten_attn_drop: float = 0.0, natten_proj_drop: float = 0.0, ): - super(TowerUNetBlock, self).__init__() + super().__init__() in_channels = ( backbone_side_channels + backbone_down_channels + up_channels * 2 @@ -561,1492 +561,3 @@ def forward( x = torch.cat((x, tower_down), dim=1) return self.res_conv(x) - - -class ResELUNetPsiLayer(nn.Module): - def __init__( - self, - out_channels: int, - side_in: T.Dict[str, int] = None, - down_in: T.Dict[str, int] = None, - dilations: T.Sequence[int] = None, - attention_weights: str = AttentionTypes.SPATIAL_CHANNEL, - activation_type: str = "SiLU", - ): - super(ResELUNetPsiLayer, self).__init__() - - self.up = UpSample() - if dilations is None: - dilations = [2] - - cat_channels = 0 - - module_dict = {} - - if side_in is not None: - for name, in_channels in side_in.items(): - module_dict[name] = ResidualConv( - in_channels=in_channels, - out_channels=out_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - cat_channels += out_channels - - if down_in is not None: - for name, in_channels in down_in.items(): 
- module_dict[name] = ResidualConv( - in_channels=in_channels, - out_channels=out_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - cat_channels += out_channels - - self.module_dict = nn.ModuleDict(module_dict) - - self.final = ResidualConv( - in_channels=cat_channels, - out_channels=out_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - - def forward( - self, - side: T.Dict[str, torch.Tensor], - down: T.Dict[str, torch.Tensor], - shape: tuple, - ) -> torch.Tensor: - out = [] - for name, x in side.items(): - layer = self.module_dict[name] - assert x is not None, 'A tensor must be given.' - out += [layer(x)] - - for name, x in down.items(): - layer = self.module_dict[name] - x = self.up( - x, - size=shape, - mode="bilinear", - ) - out += [layer(x)] - - out = torch.cat(out, dim=1) - out = self.final(out) - - return out - - -class ResELUNetPsiBlock(nn.Module): - def __init__( - self, - out_channels: int, - side_in: dict, - down_in: dict, - dilations: T.Sequence[int] = None, - attention_weights: str = AttentionTypes.SPATIAL_CHANNEL, - activation_type: str = "SiLU", - ): - super(ResELUNetPsiBlock, self).__init__() - - self.dist_layer = ResELUNetPsiLayer( - out_channels=out_channels, - side_in=side_in['dist'], - down_in=down_in['dist'], - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.edge_layer = ResELUNetPsiLayer( - out_channels=out_channels, - side_in=side_in['edge'], - down_in=down_in['edge'], - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.mask_layer = ResELUNetPsiLayer( - out_channels=out_channels, - side_in=side_in['mask'], - down_in=down_in['mask'], - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - - def update_data( - self, - data_dict: T.Dict[str, T.Union[None, torch.Tensor]], - data: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - out = data_dict.copy() - for key, x in data_dict.items(): - if x is None: - out[key] = data - - return out - - def forward( - self, - side: T.Dict[str, T.Union[None, torch.Tensor]], - down: T.Dict[str, T.Union[None, torch.Tensor]], - shape: tuple, - ) -> dict: - dist_out = self.dist_layer( - side=side['dist'], - down=down['dist'], - shape=shape, - ) - - edge_out = self.edge_layer( - side=self.update_data(side['edge'], dist_out), - down=down['edge'], - shape=shape, - ) - - mask_out = self.mask_layer( - side=self.update_data(side['mask'], edge_out), - down=down['mask'], - shape=shape, - ) - - return { - "dist": dist_out, - "edge": edge_out, - "mask": mask_out, - } - - -class UNet3Connector(nn.Module): - """Connects layers in a UNet 3+ architecture.""" - - def __init__( - self, - channels: T.List[int], - up_channels: int, - prev_backbone_channel_index: int, - use_backbone: bool = True, - is_side_stream: bool = True, - n_pools: int = 0, - n_prev_down: int = 0, - n_stream_down: int = 0, - prev_down_is_pooled: bool = False, - attention_weights: str = AttentionTypes.SPATIAL_CHANNEL, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - model_type: str = ModelTypes.UNET, - res_block_type: str = ResBlockTypes.RESA, - activation_type: str = "SiLU", - ): - super(UNet3Connector, self).__init__() - - assert attention_weights in [ - "gate", - AttentionTypes.FRACTAL, - AttentionTypes.SPATIAL_CHANNEL, - ], "Choose from 'gate', 'fractal', or 'spatial_channel' 
attention weights." - - assert model_type in ( - ModelTypes.UNET, - ModelTypes.RESUNET, - ModelTypes.RESUNET3PSI, - ModelTypes.RESELUNETPSI, - ) - assert res_block_type in ( - ResBlockTypes.RES, - ResBlockTypes.RESA, - ) - - self.n_pools = n_pools - self.n_prev_down = n_prev_down - self.n_stream_down = n_stream_down - self.attention_weights = attention_weights - self.use_backbone = use_backbone - self.is_side_stream = is_side_stream - self.cat_channels = 0 - self.pool4_0 = None - - self.up = UpSample() - - if dilations is None: - dilations = [2] - - # Pool layers - if n_pools > 0: - if n_pools == 3: - pool_size = 8 - elif n_pools == 2: - pool_size = 4 - else: - pool_size = 2 - - for n in range(0, n_pools): - if model_type == ModelTypes.UNET: - setattr( - self, - f"pool_{n}", - PoolConv( - in_channels=channels[n], - out_channels=channels[0], - pool_size=pool_size, - double_dilation=dilations[0], - activation_type=activation_type, - ), - ) - else: - setattr( - self, - f"pool_{n}", - PoolResidualConv( - in_channels=channels[n], - out_channels=channels[0], - pool_size=pool_size, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - res_block_type=res_block_type, - ), - ) - pool_size = int(pool_size / 2) - self.cat_channels += channels[0] - if self.use_backbone: - if model_type == ModelTypes.UNET: - self.prev_backbone = DoubleConv( - in_channels=channels[prev_backbone_channel_index], - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ) - else: - if res_block_type == ResBlockTypes.RES: - self.prev_backbone = ResidualConv( - in_channels=channels[prev_backbone_channel_index], - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - else: - self.prev_backbone = ResidualAConv( - in_channels=channels[prev_backbone_channel_index], - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.cat_channels += up_channels - if self.is_side_stream: - if model_type == ModelTypes.UNET: - # Backbone, same level - self.prev = DoubleConv( - in_channels=up_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ) - else: - if res_block_type == ResBlockTypes.RES: - self.prev = ResidualConv( - in_channels=up_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - else: - self.prev = ResidualAConv( - in_channels=up_channels, - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.cat_channels += up_channels - # Previous output, downstream - if self.n_prev_down > 0: - for n in range(0, self.n_prev_down): - if model_type == ModelTypes.UNET: - setattr( - self, - f"prev_{n}", - DoubleConv( - in_channels=up_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ), - ) - else: - if res_block_type == ResBlockTypes.RES: - setattr( - self, - f"prev_{n}", - ResidualConv( - in_channels=up_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - else: - setattr( - self, - f"prev_{n}", - ResidualAConv( - in_channels=up_channels, - 
out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - self.cat_channels += up_channels - - # Previous output, (same) downstream - if self.n_stream_down > 0: - for n in range(0, self.n_stream_down): - in_stream_channels = up_channels - if self.attention_weights is not None and ( - self.attention_weights == "gate" - ): - attention_module = AttentionGate(up_channels, up_channels) - setattr(self, f"attn_stream_{n}", attention_module) - in_stream_channels = up_channels * 2 - - # All but the last inputs are pooled - if prev_down_is_pooled and (n + 1 < self.n_stream_down): - in_stream_channels = channels[ - prev_backbone_channel_index - + (self.n_stream_down - 1) - - n - ] - - if model_type == ModelTypes.UNET: - setattr( - self, - f"stream_{n}", - DoubleConv( - in_channels=in_stream_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ), - ) - else: - if res_block_type == ResBlockTypes.RES: - setattr( - self, - f"stream_{n}", - ResidualConv( - in_channels=in_stream_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - else: - setattr( - self, - f"stream_{n}", - ResidualAConv( - in_channels=in_stream_channels, - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - self.cat_channels += up_channels - - self.cat_channels += channels[0] - if model_type == ModelTypes.UNET: - self.conv4_0 = DoubleConv( - in_channels=channels[4], - out_channels=channels[0], - init_point_conv=init_point_conv, - activation_type=activation_type, - ) - self.final = DoubleConv( - in_channels=self.cat_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ) - else: - if res_block_type == ResBlockTypes.RES: - self.conv4_0 = ResidualConv( - in_channels=channels[4], - out_channels=channels[0], - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.final = ResidualConv( - in_channels=self.cat_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - else: - self.conv4_0 = ResidualAConv( - in_channels=channels[4], - out_channels=channels[0], - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.final = ResidualAConv( - in_channels=self.cat_channels, - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - - def forward( - self, - prev_same: T.List[T.Tuple[str, torch.Tensor]], - x4_0: torch.Tensor = None, - pools: T.List[torch.Tensor] = None, - prev_down: T.List[torch.Tensor] = None, - stream_down: T.List[torch.Tensor] = None, - ): - h: T.List[torch.Tensor] = [] - # Pooling layer of the backbone - if pools is not None: - assert self.n_pools == len( - pools - ), "There are no convolutions available for the pool layers." - for n, x in zip(range(self.n_pools), pools): - c = getattr(self, f"pool_{n}") - h += [c(x)] - # Up down layers from the previous head - if prev_down is not None: - assert self.n_prev_down == len( - prev_down - ), "There are no convolutions available for the previous downstream layers." 
- for n, x in zip(range(self.n_prev_down), prev_down): - c = getattr(self, f"prev_{n}") - h += [ - c( - self.up( - x, size=prev_same[0][1].shape[-2:], mode="bilinear" - ) - ) - ] - assert len(prev_same) == sum( - [self.use_backbone, self.is_side_stream] - ), "The previous same layers do not match the setup." - # Previous same layers from the previous head - for conv_name, prev_inputs in prev_same: - c = getattr(self, conv_name) - h += [c(prev_inputs)] - if self.attention_weights is not None and ( - self.attention_weights == "gate" - ): - prev_same_hidden = h[-1].clone() - # Previous down layers from the same head - if stream_down is not None: - assert self.n_stream_down == len( - stream_down - ), "There are no convolutions available for the downstream layers." - for n, x in zip(range(self.n_stream_down), stream_down): - if self.attention_weights is not None and ( - self.attention_weights == "gate" - ): - # Gate - g = self.up( - x, size=prev_same[0][1].shape[-2:], mode="bilinear" - ) - c_attn = getattr(self, f"attn_stream_{n}") - # Attention gate - attn_out = c_attn(g, prev_same_hidden) - c = getattr(self, f"stream_{n}") - # Concatenate attention weights - h += [c(torch.cat([attn_out, g], dim=1))] - else: - c = getattr(self, f"stream_{n}") - h += [ - c( - self.up( - x, - size=prev_same[0][1].shape[-2:], - mode="bilinear", - ) - ) - ] - - # Lowest level - if x4_0 is not None: - x4_0_up = self.conv4_0( - self.up(x4_0, size=prev_same[0][1].shape[-2:], mode="bilinear") - ) - if self.pool4_0 is not None: - h += [self.pool4_0(x4_0_up)] - else: - h += [x4_0_up] - h = torch.cat(h, dim=1) - h = self.final(h) - - return h - - -class UNet3P_3_1(nn.Module): - """UNet 3+ connection from backbone to upstream 3,1.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "SiLU", - ): - super(UNet3P_3_1, self).__init__() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - x3_0: torch.Tensor, - x4_0: torch.Tensor, - ) -> torch.Tensor: - h = self.conv( - prev_same=[("prev_backbone", x3_0)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - - return h - - -class UNet3P_2_2(nn.Module): - """UNet 3+ connection from backbone to upstream 2,2.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "SiLU", - ): - super(UNet3P_2_2, self).__init__() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - h3_1: torch.Tensor, - x4_0: torch.Tensor, - ) -> torch.Tensor: - h = self.conv( - prev_same=[("prev_backbone", x2_0)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1], - ) - - return h - - -class UNet3P_1_3(nn.Module): - """UNet 3+ connection from backbone to upstream 1,3.""" - - def __init__( - self, - 
channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "SiLU", - ): - super(UNet3P_1_3, self).__init__() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - h2_2: torch.Tensor, - h3_1: torch.Tensor, - x4_0: torch.Tensor, - ) -> torch.Tensor: - h = self.conv( - prev_same=[("prev_backbone", x1_0)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1, h2_2], - ) - - return h - - -class UNet3P_0_4(nn.Module): - """UNet 3+ connection from backbone to upstream 0,4.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "SiLU", - ): - super(UNet3P_0_4, self).__init__() - - self.up = UpSample() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - h1_3: torch.Tensor, - h2_2: torch.Tensor, - h3_1: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h = self.conv( - prev_same=[("prev_backbone", x0_0)], - x4_0=x4_0, - stream_down=[h3_1, h2_2, h1_3], - ) - - return h - - -class UNet3_3_1(nn.Module): - """UNet 3+ connection from backbone to upstream 3,1.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "SiLU", - ): - super(UNet3_3_1, self).__init__() - - self.up = UpSample() - - # Distance stream connection - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - # Edge stream connection - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - # Mask stream connection - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - x3_0: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - # Distance logits - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x3_0)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - # Output distance logits pass to edge layer - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x3_0), ("prev", h_dist)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - # Output edge logits pass to mask layer - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x3_0), ("prev", h_edge)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": 
h_mask, - } - - -class UNet3_2_2(nn.Module): - """UNet 3+ connection from backbone to upstream 2,2.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "SiLU", - ): - super(UNet3_2_2, self).__init__() - - self.up = UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - h3_1_dist: torch.Tensor, - h3_1_edge: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x2_0)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x2_0), ("prev", h_dist)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x2_0), ("prev", h_edge)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class UNet3_1_3(nn.Module): - """UNet 3+ connection from backbone to upstream 1,3.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "SiLU", - ): - super(UNet3_1_3, self).__init__() - - self.up = UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - h2_2_dist: torch.Tensor, - h3_1_dist: torch.Tensor, - h2_2_edge: torch.Tensor, - h3_1_edge: torch.Tensor, - h2_2_mask: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x1_0)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_dist, h2_2_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x1_0), ("prev", h_dist)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_edge, h2_2_edge], - ) - h_mask = self.conv_mask( - 
prev_same=[("prev_backbone", x1_0), ("prev", h_edge)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_mask, h2_2_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class UNet3_0_4(nn.Module): - """UNet 3+ connection from backbone to upstream 0,4.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "SiLU", - ): - super(UNet3_0_4, self).__init__() - - self.up = UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - h1_3_dist: torch.Tensor, - h2_2_dist: torch.Tensor, - h3_1_dist: torch.Tensor, - h1_3_edge: torch.Tensor, - h2_2_edge: torch.Tensor, - h3_1_edge: torch.Tensor, - h1_3_mask: torch.Tensor, - h2_2_mask: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x0_0)], - x4_0=x4_0, - stream_down=[h3_1_dist, h2_2_dist, h1_3_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x0_0), ("prev", h_dist)], - x4_0=x4_0, - stream_down=[h3_1_edge, h2_2_edge, h1_3_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x0_0), ("prev", h_edge)], - x4_0=x4_0, - stream_down=[h3_1_mask, h2_2_mask, h1_3_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -def get_prev_list( - use_backbone: bool, - x: torch.Tensor, - prev_same: T.List[tuple], -) -> T.List[tuple]: - prev = [ - ( - "prev", - x, - ) - ] - if use_backbone: - prev += prev_same - - return prev - - -class ResUNet3_3_1(nn.Module): - """Residual UNet 3+ connection from backbone to upstream 3,1.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - n_pools: int = 3, - use_backbone: bool = True, - dilations: T.Sequence[int] = None, - attention_weights: str = AttentionTypes.SPATIAL_CHANNEL, - activation_type: str = "SiLU", - res_block_type: str = ResBlockTypes.RESA, - model_type: str = ModelTypes.RESUNET, - ): - super(ResUNet3_3_1, self).__init__() - - self.use_backbone = use_backbone - self.up = UpSample() - - # Distance stream connection - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=3, - n_pools=n_pools, - dilations=dilations, - attention_weights=attention_weights, - model_type=model_type, - res_block_type=res_block_type, - activation_type=activation_type, - ) - # Edge stream connection - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=use_backbone, - is_side_stream=True, - prev_backbone_channel_index=3, - n_pools=n_pools, - dilations=dilations, - attention_weights=attention_weights, - model_type=model_type, - 
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-        # Mask stream connection
-        self.conv_mask = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=use_backbone,
-            is_side_stream=True,
-            prev_backbone_channel_index=3,
-            n_pools=n_pools,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-
-    def forward(
-        self,
-        side: torch.Tensor,
-        down: torch.Tensor,
-        pools: T.Sequence[torch.Tensor] = None,
-    ) -> T.Dict[str, torch.Tensor]:
-        prev_same = [
-            (
-                "prev_backbone",
-                side,
-            )
-        ]
-        # Distance logits
-        h_dist = self.conv_dist(
-            prev_same=prev_same,
-            pools=pools,
-            x4_0=down,
-        )
-        # Output distance logits pass to edge layer
-        h_edge = self.conv_edge(
-            prev_same=get_prev_list(self.use_backbone, h_dist, prev_same),
-            pools=pools,
-            x4_0=down,
-        )
-        # Output edge logits pass to mask layer
-        h_mask = self.conv_mask(
-            prev_same=get_prev_list(self.use_backbone, h_edge, prev_same),
-            pools=pools,
-            x4_0=down,
-        )
-
-        return {
-            "dist": h_dist,
-            "edge": h_edge,
-            "mask": h_mask,
-        }
-
-
-class ResUNet3_2_2(nn.Module):
-    """Residual UNet 3+ connection from backbone to upstream 2,2."""
-
-    def __init__(
-        self,
-        channels: T.Sequence[int],
-        up_channels: int,
-        n_pools: int = 2,
-        use_backbone: bool = True,
-        n_stream_down: int = 1,
-        prev_down_is_pooled: bool = False,
-        dilations: T.Sequence[int] = None,
-        attention_weights: str = AttentionTypes.SPATIAL_CHANNEL,
-        activation_type: str = "SiLU",
-        res_block_type: str = ResBlockTypes.RESA,
-        model_type: str = ModelTypes.RESUNET,
-    ):
-        super(ResUNet3_2_2, self).__init__()
-
-        self.use_backbone = use_backbone
-        self.up = UpSample()
-
-        self.conv_dist = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=True,
-            is_side_stream=False,
-            prev_backbone_channel_index=2,
-            n_pools=n_pools,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=prev_down_is_pooled,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-        self.conv_edge = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=use_backbone,
-            is_side_stream=True,
-            prev_backbone_channel_index=2,
-            n_pools=n_pools,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=False,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-        self.conv_mask = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=use_backbone,
-            is_side_stream=True,
-            prev_backbone_channel_index=2,
-            n_pools=n_pools,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=False,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-
-    def forward(
-        self,
-        side: torch.Tensor,
-        dist_down: T.Sequence[torch.Tensor],
-        edge_down: T.Sequence[torch.Tensor],
-        mask_down: T.Sequence[torch.Tensor],
-        down: torch.Tensor = None,
-        pools: T.Sequence[torch.Tensor] = None,
-    ) -> T.Dict[str, torch.Tensor]:
-        prev_same = [
-            (
-                "prev_backbone",
-                side,
-            )
-        ]
-
-        h_dist = self.conv_dist(
-            prev_same=prev_same,
-            pools=pools,
-            x4_0=down,
-            stream_down=dist_down,
-        )
-        h_edge = self.conv_edge(
-            prev_same=get_prev_list(self.use_backbone, h_dist, prev_same),
-            pools=pools,
-            x4_0=down,
-            stream_down=edge_down,
-        )
-        h_mask = self.conv_mask(
-            prev_same=get_prev_list(self.use_backbone, h_edge, prev_same),
-            pools=pools,
-            x4_0=down,
-            stream_down=mask_down,
-        )
-
-        return {
-            "dist": h_dist,
-            "edge": h_edge,
-            "mask": h_mask,
-        }
-
-
-class ResUNet3_1_3(nn.Module):
-    """Residual UNet 3+ connection from backbone to upstream 1,3."""
-
-    def __init__(
-        self,
-        channels: T.Sequence[int],
-        up_channels: int,
-        n_pools: int = 1,
-        use_backbone: bool = True,
-        n_stream_down: int = 2,
-        prev_down_is_pooled: bool = False,
-        dilations: T.Sequence[int] = None,
-        attention_weights: str = AttentionTypes.SPATIAL_CHANNEL,
-        activation_type: str = "SiLU",
-        res_block_type: enum = ResBlockTypes.RESA,
-        model_type: str = ModelTypes.RESUNET,
-    ):
-        super(ResUNet3_1_3, self).__init__()
-
-        self.use_backbone = use_backbone
-        self.up = UpSample()
-
-        self.conv_dist = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=True,
-            is_side_stream=False,
-            prev_backbone_channel_index=1,
-            n_pools=n_pools,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=prev_down_is_pooled,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-        self.conv_edge = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=use_backbone,
-            is_side_stream=True,
-            prev_backbone_channel_index=1,
-            n_pools=n_pools,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=prev_down_is_pooled,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-        self.conv_mask = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=use_backbone,
-            is_side_stream=True,
-            prev_backbone_channel_index=1,
-            n_pools=n_pools,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=prev_down_is_pooled,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-
-    def forward(
-        self,
-        side: torch.Tensor,
-        dist_down: T.Sequence[torch.Tensor],
-        edge_down: T.Sequence[torch.Tensor],
-        mask_down: T.Sequence[torch.Tensor],
-        down: torch.Tensor = None,
-        pools: T.Sequence[torch.Tensor] = None,
-    ) -> T.Dict[str, torch.Tensor]:
-        prev_same = [
-            (
-                "prev_backbone",
-                side,
-            )
-        ]
-
-        h_dist = self.conv_dist(
-            prev_same=prev_same,
-            pools=pools,
-            x4_0=down,
-            stream_down=dist_down,
-        )
-        h_edge = self.conv_edge(
-            prev_same=get_prev_list(self.use_backbone, h_dist, prev_same),
-            pools=pools,
-            x4_0=down,
-            stream_down=edge_down,
-        )
-        h_mask = self.conv_mask(
-            prev_same=get_prev_list(self.use_backbone, h_edge, prev_same),
-            pools=pools,
-            x4_0=down,
-            stream_down=mask_down,
-        )
-
-        return {
-            "dist": h_dist,
-            "edge": h_edge,
-            "mask": h_mask,
-        }
-
-
-class ResUNet3_0_4(nn.Module):
-    """Residual UNet 3+ connection from backbone to upstream 0,4."""
-
-    def __init__(
-        self,
-        channels: T.Sequence[int],
-        up_channels: int,
-        n_stream_down: int = 3,
-        use_backbone: bool = True,
-        prev_down_is_pooled: bool = False,
-        dilations: T.Sequence[int] = None,
-        attention_weights: str = AttentionTypes.SPATIAL_CHANNEL,
-        activation_type: str = "SiLU",
-        res_block_type: str = ResBlockTypes.RESA,
-        model_type: str = ModelTypes.RESUNET,
-    ):
-        super(ResUNet3_0_4, self).__init__()
-
-        self.use_backbone = use_backbone
-        self.up = UpSample()
-
-        self.conv_dist = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=True,
-            is_side_stream=False,
-            prev_backbone_channel_index=0,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=prev_down_is_pooled,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-        self.conv_edge = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=use_backbone,
-            is_side_stream=True,
-            prev_backbone_channel_index=0,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=prev_down_is_pooled,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-        self.conv_mask = UNet3Connector(
-            channels=channels,
-            up_channels=up_channels,
-            use_backbone=use_backbone,
-            is_side_stream=True,
-            prev_backbone_channel_index=0,
-            n_stream_down=n_stream_down,
-            prev_down_is_pooled=prev_down_is_pooled,
-            dilations=dilations,
-            attention_weights=attention_weights,
-            model_type=model_type,
-            res_block_type=res_block_type,
-            activation_type=activation_type,
-        )
-
-    def forward(
-        self,
-        side: torch.Tensor,
-        dist_down: T.Sequence[torch.Tensor],
-        edge_down: T.Sequence[torch.Tensor],
-        mask_down: T.Sequence[torch.Tensor],
-        down: torch.Tensor = None,
-    ) -> T.Dict[str, torch.Tensor]:
-        prev_same = [
-            (
-                "prev_backbone",
-                side,
-            )
-        ]
-
-        h_dist = self.conv_dist(
-            prev_same=prev_same,
-            x4_0=down,
-            stream_down=dist_down,
-        )
-        h_edge = self.conv_edge(
-            prev_same=get_prev_list(self.use_backbone, h_dist, prev_same),
-            x4_0=down,
-            stream_down=edge_down,
-        )
-        h_mask = self.conv_mask(
-            prev_same=get_prev_list(self.use_backbone, h_edge, prev_same),
-            x4_0=down,
-            stream_down=mask_down,
-        )
-
-        return {
-            "dist": h_dist,
-            "edge": h_edge,
-            "mask": h_mask,
-        }
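Note: the removed `ResUNet3_*` connectors above all implement the same cascade: the distance head runs first, its output is threaded into the edge head via `get_prev_list`, and the edge output is threaded into the mask head. A minimal sketch of that conditioning pattern follows; `Head` and the channel sizes are hypothetical stand-ins for `UNet3Connector`, not the removed code.

```python
import torch
import torch.nn as nn


class Head(nn.Module):
    """Hypothetical stand-in for UNet3Connector (shape-preserving conv)."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class CascadedStreams(nn.Module):
    """Sketch of the dist -> edge -> mask side-stream conditioning."""

    def __init__(self, channels: int):
        super().__init__()
        self.dist_head = Head(channels)
        self.edge_head = Head(channels)
        self.mask_head = Head(channels)

    def forward(self, side: torch.Tensor) -> dict:
        # Each downstream head sees the previous head's activation.
        h_dist = self.dist_head(side)
        h_edge = self.edge_head(side + h_dist)
        h_mask = self.mask_head(side + h_edge)
        return {"dist": h_dist, "edge": h_edge, "mask": h_mask}


streams = CascadedStreams(channels=8)
out = streams(torch.rand(2, 8, 16, 16))
assert set(out) == {"dist", "edge", "mask"}
```

The ordering matters: distance-to-boundary is the easiest signal, and conditioning the edge and mask heads on it is what made the three streams mutually consistent.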
diff --git a/src/cultionet/nn/modules/utils.py b/src/cultionet/nn/modules/utils.py
index 6fca55ee..ffd983b0 100644
--- a/src/cultionet/nn/modules/utils.py
+++ b/src/cultionet/nn/modules/utils.py
@@ -1,7 +1,4 @@
-import typing as T
-
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 
 
@@ -15,85 +12,3 @@ def check_upsample(x: torch.Tensor, size: torch.Size) -> torch.Tensor:
         )
 
     return x
-
-
-class Permute(nn.Module):
-    def __init__(self, axis_order: T.Sequence[int]):
-        super(Permute, self).__init__()
-        self.axis_order = axis_order
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.permute(*self.axis_order)
-
-
-class Add(nn.Module):
-    def __init__(self):
-        super(Add, self).__init__()
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return x + y
-
-
-class Min(nn.Module):
-    def __init__(self, dim: int, keepdim: bool = False):
-        super(Min, self).__init__()
-
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.min(dim=self.dim, keepdim=self.keepdim)[0]
-
-
-class Max(nn.Module):
-    def __init__(self, dim: int, keepdim: bool = False):
-        super(Max, self).__init__()
-
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.max(dim=self.dim, keepdim=self.keepdim)[0]
-
-
-class Mean(nn.Module):
-    def __init__(self, dim: int, keepdim: bool = False):
-        super(Mean, self).__init__()
-
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.mean(dim=self.dim, keepdim=self.keepdim)
-
-
-class Var(nn.Module):
-    def __init__(
-        self, dim: int, keepdim: bool = False, unbiased: bool = False
-    ):
-        super(Var, self).__init__()
-
-        self.dim = dim
-        self.keepdim = keepdim
-        self.unbiased = unbiased
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.var(
-            dim=self.dim, keepdim=self.keepdim, unbiased=self.unbiased
-        )
-
-
-class Std(nn.Module):
-    def __init__(
-        self, dim: int, keepdim: bool = False, unbiased: bool = False
-    ):
-        super(Std, self).__init__()
-
-        self.dim = dim
-        self.keepdim = keepdim
-        self.unbiased = unbiased
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return x.std(
-            dim=self.dim, keepdim=self.keepdim, unbiased=self.unbiased
-        )
diff --git a/src/cultionet/scripts/args.yml b/src/cultionet/scripts/args.yml
index 288349d4..8e12c157 100644
--- a/src/cultionet/scripts/args.yml
+++ b/src/cultionet/scripts/args.yml
@@ -181,8 +181,6 @@ train_predict:
     kwargs:
       default: 'TowerUNet'
      choices:
-        - 'UNet3Psi'
-        - 'ResUNet3Psi'
         - 'TowerUNet'
   activation_type:
     short: ''
@@ -193,7 +191,7 @@ train_predict:
   res_block_type:
     short: rb
     long: res-block-type
-    help: The residual block type (only relevant when --model-type=ResUNet3Psi)
+    help: The residual block type
     kwargs:
       default: 'resa'
       choices: ['res', 'resa']
diff --git a/src/cultionet/scripts/cultionet.py b/src/cultionet/scripts/cultionet.py
index 26999034..f7fb9a4b 100644
--- a/src/cultionet/scripts/cultionet.py
+++ b/src/cultionet/scripts/cultionet.py
@@ -45,7 +45,8 @@
 from cultionet.errors import TensorShapeError
 from cultionet.model import CultionetParams
 from cultionet.utils import model_preprocessing
-from cultionet.utils.logging import ParallelProgress, set_color_logger
+from cultionet.utils.logging import set_color_logger
+from cultionet.utils.model_preprocessing import ParallelProgress
 from cultionet.utils.normalize import NormValues
 from cultionet.utils.project_paths import ProjectPaths, setup_paths
diff --git a/src/cultionet/utils/model_preprocessing.py b/src/cultionet/utils/model_preprocessing.py
index 440a5cb5..cfa78f30 100644
--- a/src/cultionet/utils/model_preprocessing.py
+++ b/src/cultionet/utils/model_preprocessing.py
@@ -8,7 +8,7 @@
 from tqdm.auto import tqdm
 
 
-class TqdmParallel(Parallel):
+class ParallelProgress(Parallel):
     """A tqdm progress bar for joblib Parallel tasks.
 
     Reference:
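Note: `ParallelProgress` (formerly `TqdmParallel`) follows the common tqdm-over-joblib recipe: subclass `Parallel`, open a bar around `__call__`, and override `print_progress` to mirror `n_completed_tasks` onto the bar. A self-contained sketch of that recipe, not the file's exact code:

```python
from joblib import Parallel, delayed
from tqdm.auto import tqdm


class ParallelProgressSketch(Parallel):
    """Sketch: joblib Parallel with a tqdm bar (illustrative, assumed API)."""

    def __init__(self, tqdm_kwargs: dict = None, **kwargs):
        self.tqdm_kwargs = tqdm_kwargs or {}
        super().__init__(**kwargs)

    def __call__(self, *args, **kwargs):
        # Keep the bar open for the lifetime of the task batch.
        with tqdm(**self.tqdm_kwargs) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        # joblib calls this as tasks complete; sync the bar to the count.
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()


# Usage mirrors the call sites in this diff: a context manager wrapping a
# generator of delayed calls.
with ParallelProgressSketch(tqdm_kwargs={"total": 4, "desc": "Demo"}) as pool:
    results = pool(delayed(abs)(v) for v in range(-4, 0))
```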
diff --git a/src/cultionet/utils/normalize.py b/src/cultionet/utils/normalize.py
index 1d1ef310..4277409f 100644
--- a/src/cultionet/utils/normalize.py
+++ b/src/cultionet/utils/normalize.py
@@ -19,7 +19,7 @@
 from ..data.data import Data
 from ..data.utils import collate_fn
-from .model_preprocessing import TqdmParallel
+from .model_preprocessing import ParallelProgress
 from .stats import Quantile, Variance, cache_load_enabled, tally_stats
 
@@ -28,6 +28,8 @@ def add_dim(d: torch.Tensor) -> torch.Tensor:
 
 
 class NormValues:
+    """Normalization values."""
+
     def __init__(
         self,
         dataset_mean: torch.Tensor,
@@ -131,199 +133,145 @@ def from_dataset(
         lower_bound = None
         upper_bound = None
 
-        if not isinstance(dataset, Dataset):
-            data_loader = DataLoader(
-                dataset,
-                batch_size=batch_size,
-                shuffle=False,
-                num_workers=0,
-                collate_fn=collate_fn,
-            )
-
-            data_maxs = torch.zeros(3, dtype=torch.float)
-            data_sums = torch.zeros(3, dtype=torch.float)
-            sse = torch.zeros(3, dtype=torch.float)
-            pix_count = 0.0
-            with tqdm(
-                total=int(len(dataset) / batch_size),
-                desc='Calculating means',
-                colour=mean_color,
-            ) as pbar:
-                for x, y in data_loader:
-                    channel_maxs = torch.tensor(
-                        [x[0, c, ...].max() for c in range(0, x.shape[1])],
-                        dtype=torch.float,
-                    )
-                    data_maxs = torch.where(
-                        channel_maxs > data_maxs, channel_maxs, data_maxs
-                    )
-                    # Sum over all data
-                    data_sums += x.sum(dim=(0, 2, 3))
-                    pix_count += x.shape[2] * x.shape[3]
+        data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            shuffle=False,
+            collate_fn=collate_fn,
+        )
 
-                    pbar.update(1)
+        if centering == 'median':
+            stat_var = Variance(method='median')
+            stat_q = Quantile(r=1024 * 6)
+            tmp_cache_path = Path.home().absolute() / '.cultionet'
+            tmp_cache_path.mkdir(parents=True, exist_ok=True)
+            var_data_cache = tmp_cache_path / '_var.npz'
+            q_data_cache = tmp_cache_path / '_q.npz'
+            crop_counts = torch.zeros(class_info['max_crop_class'] + 1).long()
+            edge_counts = torch.zeros(2).long()
+            with cache_load_enabled(True):
+                with Progress(
+                    TextColumn(
+                        "Calculating stats", style=Style(color="#cacaca")
+                    ),
+                    TextColumn("•", style=Style(color="#cacaca")),
+                    BarColumn(
+                        style="#ACFCD6",
+                        complete_style="#AA9439",
+                        finished_style="#ACFCD6",
+                        pulse_style="#FCADED",
+                    ),
+                    TaskProgressColumn(),
+                    TextColumn("•", style=Style(color="#cacaca")),
+                    TimeElapsedColumn(),
+                ) as pbar:
+                    for batch in pbar.track(
+                        tally_stats(
+                            stats=(stat_var, stat_q),
+                            loader=data_loader,
+                            caches=(var_data_cache, q_data_cache),
+                        ),
+                        total=len(data_loader),
+                    ):
+                        # Stack samples
+                        x = rearrange(batch.x, 'b c t h w -> (b t h w) c')
+
+                        # Update the stats
+                        stat_var.add(x)
+                        stat_q.add(x)
+
+                        # Update counts
+                        crop_counts[0] += (
+                            (batch.y == 0)
+                            | (batch.y == class_info['edge_class'])
+                        ).sum()
+                        for i in range(1, class_info['edge_class']):
+                            crop_counts[i] += (batch.y == i).sum()
+
+                        edge_counts[0] += (
+                            (batch.y >= 0)
+                            & (batch.y != class_info['edge_class'])
+                        ).sum()
+                        edge_counts[1] += (
+                            batch.y == class_info['edge_class']
+                        ).sum()
+
+            data_stds = stat_var.std()
+            data_means = stat_q.median()
+            lower_bound = stat_q.quantiles(0.3)
+            upper_bound = stat_q.quantiles(0.7)
+
+            var_data_cache.unlink()
+            q_data_cache.unlink()
+            tmp_cache_path.rmdir()
 
-            data_means = data_sums / float(pix_count)
-            with tqdm(
-                total=int(len(dataset) / batch_size),
-                desc='Calculating SSEs',
-                colour=sse_color,
-            ) as pbar:
-                for x, y in data_loader:
-                    sse += (
-                        (x - data_means.unsqueeze(0)[..., None, None]).pow(2)
-                    ).sum(dim=(0, 2, 3))
+        else:
 
-                    pbar.update(1)
+            def get_info(
+                x: torch.Tensor, y: torch.Tensor
+            ) -> T.Tuple[torch.Tensor, int, torch.Tensor, torch.Tensor]:
+                crop_counts = torch.zeros(class_info['max_crop_class'] + 1)
+                edge_counts = torch.zeros(2)
+                crop_counts[0] = (
+                    (y == 0) | (y == class_info['edge_class'])
+                ).sum()
+                for i in range(1, class_info['edge_class']):
+                    crop_counts[i] = (y == i).sum()
+                edge_counts[0] = (y != class_info['edge_class']).sum()
+                edge_counts[1] = (y == class_info['edge_class']).sum()
+
+                return x.sum(dim=0), x.shape[0], crop_counts, edge_counts
+
+            with parallel_backend(
+                backend='loky',
+                n_jobs=processes,
+                inner_max_num_threads=threads_per_worker,
+            ):
+                with ParallelProgress(
+                    tqdm_kwargs={
+                        'total': int(len(dataset) / batch_size),
+                        'desc': 'Calculating means',
+                        'colour': mean_color,
+                    }
+                ) as pool:
+                    results = pool(
+                        delayed(get_info)(batch.x, batch.y)
+                        for batch in data_loader
+                    )
+            data_sums, pix_count, crop_counts, edge_counts = list(
+                map(list, zip(*results))
+            )
 
-            data_stds = torch.sqrt(sse / pix_count)
+            data_sums = torch.stack(data_sums).sum(dim=0)
+            pix_count = torch.tensor(pix_count).sum()
+            crop_counts = torch.stack(crop_counts).sum(dim=0)
+            edge_counts = torch.stack(edge_counts).sum(dim=0)
+            data_means = data_sums / float(pix_count)
 
-        else:
-            data_loader = DataLoader(
-                dataset,
-                batch_size=batch_size,
-                num_workers=num_workers,
-                shuffle=False,
-                collate_fn=collate_fn,
-            )
+            def get_sse(x_mu: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+                return ((x - x_mu).pow(2)).sum(dim=0)
+
+            sse_partial = partial(get_sse, add_dim(data_means))
+
+            with parallel_backend(
+                backend='loky',
+                n_jobs=processes,
+                inner_max_num_threads=threads_per_worker,
+            ):
+                with ParallelProgress(
+                    tqdm_kwargs={
+                        'total': int(len(dataset) / batch_size),
+                        'desc': 'Calculating SSEs',
+                        'colour': sse_color,
+                    }
+                ) as pool:
+                    sses = pool(
                        delayed(sse_partial)(batch.x) for batch in data_loader
+                    )
 
-            if centering == 'median':
-                stat_var = Variance(method='median')
-                stat_q = Quantile(r=1024 * 6)
-                tmp_cache_path = Path.home().absolute() / '.cultionet'
-                tmp_cache_path.mkdir(parents=True, exist_ok=True)
-                var_data_cache = tmp_cache_path / '_var.npz'
-                q_data_cache = tmp_cache_path / '_q.npz'
-                crop_counts = torch.zeros(
-                    class_info['max_crop_class'] + 1
-                ).long()
-                edge_counts = torch.zeros(2).long()
-                with cache_load_enabled(True):
-                    with Progress(
-                        TextColumn(
-                            "Calculating stats", style=Style(color="#cacaca")
-                        ),
-                        TextColumn("•", style=Style(color="#cacaca")),
-                        BarColumn(
-                            style="#ACFCD6",
-                            complete_style="#AA9439",
-                            finished_style="#ACFCD6",
-                            pulse_style="#FCADED",
-                        ),
-                        TaskProgressColumn(),
-                        TextColumn("•", style=Style(color="#cacaca")),
-                        TimeElapsedColumn(),
-                    ) as pbar:
-                        for batch in pbar.track(
-                            tally_stats(
-                                stats=(stat_var, stat_q),
-                                loader=data_loader,
-                                caches=(var_data_cache, q_data_cache),
-                            ),
-                            total=len(data_loader),
-                        ):
-                            # Stack samples
-                            x = rearrange(batch.x, 'b c t h w -> (b t h w) c')
-
-                            # Update the stats
-                            stat_var.add(x)
-                            stat_q.add(x)
-
-                            # Update counts
-                            crop_counts[0] += (
-                                (batch.y == 0)
-                                | (batch.y == class_info['edge_class'])
-                            ).sum()
-                            for i in range(1, class_info['edge_class']):
-                                crop_counts[i] += (batch.y == i).sum()
-
-                            edge_counts[0] += (
-                                (batch.y >= 0)
-                                & (batch.y != class_info['edge_class'])
-                            ).sum()
-                            edge_counts[1] += (
-                                batch.y == class_info['edge_class']
-                            ).sum()
-
-                data_stds = stat_var.std()
-                data_means = stat_q.median()
-                lower_bound = stat_q.quantiles(0.3)
-                upper_bound = stat_q.quantiles(0.7)
-
-                var_data_cache.unlink()
-                q_data_cache.unlink()
-                tmp_cache_path.rmdir()
-
-            else:
-
-                def get_info(
-                    x: torch.Tensor, y: torch.Tensor
-                ) -> T.Tuple[torch.Tensor, int, torch.Tensor, torch.Tensor]:
-                    crop_counts = torch.zeros(class_info['max_crop_class'] + 1)
-                    edge_counts = torch.zeros(2)
-                    crop_counts[0] = (
-                        (y == 0) | (y == class_info['edge_class'])
-                    ).sum()
-                    for i in range(1, class_info['edge_class']):
-                        crop_counts[i] = (y == i).sum()
-                    edge_counts[0] = (y != class_info['edge_class']).sum()
-                    edge_counts[1] = (y == class_info['edge_class']).sum()
-
-                    return x.sum(dim=0), x.shape[0], crop_counts, edge_counts
-
-                with parallel_backend(
-                    backend='loky',
-                    n_jobs=processes,
-                    inner_max_num_threads=threads_per_worker,
-                ):
-                    with TqdmParallel(
-                        tqdm_kwargs={
-                            'total': int(len(dataset) / batch_size),
-                            'desc': 'Calculating means',
-                            'colour': mean_color,
-                        }
-                    ) as pool:
-                        results = pool(
-                            delayed(get_info)(batch.x, batch.y)
-                            for batch in data_loader
-                        )
-                data_sums, pix_count, crop_counts, edge_counts = list(
-                    map(list, zip(*results))
-                )
-
-                data_sums = torch.stack(data_sums).sum(dim=0)
-                pix_count = torch.tensor(pix_count).sum()
-                crop_counts = torch.stack(crop_counts).sum(dim=0)
-                edge_counts = torch.stack(edge_counts).sum(dim=0)
-                data_means = data_sums / float(pix_count)
-
-                def get_sse(
-                    x_mu: torch.Tensor, x: torch.Tensor
-                ) -> torch.Tensor:
-                    return ((x - x_mu).pow(2)).sum(dim=0)
-
-                sse_partial = partial(get_sse, add_dim(data_means))
-
-                with parallel_backend(
-                    backend='loky',
-                    n_jobs=processes,
-                    inner_max_num_threads=threads_per_worker,
-                ):
-                    with TqdmParallel(
-                        tqdm_kwargs={
-                            'total': int(len(dataset) / batch_size),
-                            'desc': 'Calculating SSEs',
-                            'colour': sse_color,
-                        }
-                    ) as pool:
-                        sses = pool(
-                            delayed(sse_partial)(batch.x)
-                            for batch in data_loader
-                        )
-
-                sses = torch.stack(sses).sum(dim=0)
-                data_stds = torch.sqrt(sses / float(pix_count))
-                data_maxs = torch.zeros_like(data_means)
+            sses = torch.stack(sses).sum(dim=0)
+            data_stds = torch.sqrt(sses / float(pix_count))
 
         return cls(
             dataset_mean=rearrange(data_means, 'c -> 1 c 1 1 1'),
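Note: the refactored non-median path above computes per-channel statistics in two passes: batch sums plus a pixel count give the mean, then summed squared errors against that mean give the standard deviation. A serial toy version of the same arithmetic (batch shapes are an assumption for illustration):

```python
import torch

# Toy stand-in for the DataLoader: four (batch, channels, height, width) batches.
batches = [torch.rand(2, 3, 8, 8) for _ in range(4)]

# Pass 1: per-channel sums and the total pixel count give the mean.
data_sums = torch.stack([x.sum(dim=(0, 2, 3)) for x in batches]).sum(dim=0)
pix_count = sum(x.shape[0] * x.shape[2] * x.shape[3] for x in batches)
data_means = data_sums / float(pix_count)

# Pass 2: per-channel sums of squared errors give the standard deviation.
sse = torch.stack(
    [((x - data_means[None, :, None, None]) ** 2).sum(dim=(0, 2, 3)) for x in batches]
).sum(dim=0)
data_stds = torch.sqrt(sse / float(pix_count))

assert data_means.shape == data_stds.shape == (3,)
```

Because sums and SSEs are associative, each pass can be farmed out per batch (here via `ParallelProgress`) and reduced afterwards, which is exactly what the refactor does.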
diff --git a/tests/test_temporal_transformer.py b/tests/test_temporal_transformer.py
new file mode 100644
index 00000000..8c77c456
--- /dev/null
+++ b/tests/test_temporal_transformer.py
@@ -0,0 +1,31 @@
+import torch
+
+from cultionet.models.temporal_transformer import TemporalTransformer
+
+
+def test_temporal_transformer():
+    batch_size = 2
+    num_channels = 3
+    hidden_channels = 64
+    num_head = 8
+    d_model = 128
+    in_time = 13
+    height = 100
+    width = 100
+
+    x = torch.rand(
+        (batch_size, num_channels, in_time, height, width),
+        dtype=torch.float32,
+    )
+
+    model = TemporalTransformer(
+        in_channels=num_channels,
+        hidden_channels=hidden_channels,
+        num_head=num_head,
+        d_model=d_model,
+        in_time=in_time,
+    )
+    output = model(x)
+
+    assert tuple(output.keys()) == ('l2', 'l3', 'encoded')
+    assert output['encoded'].shape == (batch_size, hidden_channels, height, width)
diff --git a/tests/test_tower_unet.py b/tests/test_tower_unet.py
new file mode 100644
index 00000000..7b35c859
--- /dev/null
+++ b/tests/test_tower_unet.py
@@ -0,0 +1,43 @@
+import torch
+
+from cultionet.enums import AttentionTypes, ResBlockTypes
+from cultionet.models.nunet import TowerUNet
+
+
+def test_tower_unet():
+    batch_size = 2
+    num_channels = 3
+    hidden_channels = 32
+    num_time = 13
+    height = 100
+    width = 100
+
+    x = torch.rand(
+        (batch_size, num_channels, num_time, height, width),
+        dtype=torch.float32,
+    )
+    logits_hidden = torch.rand(
+        (batch_size, hidden_channels, height, width), dtype=torch.float32
+    )
+
+    model = TowerUNet(
+        in_channels=num_channels,
+        in_time=num_time,
+        hidden_channels=hidden_channels,
+        dilations=[1, 2],
+        dropout=0.2,
+        res_block_type=ResBlockTypes.RESA,
+        attention_weights=AttentionTypes.SPATIAL_CHANNEL,
+        deep_supervision=False,
+        pool_attention=False,
+        pool_by_max=False,
+        repeat_resa_kernel=False,
+        batchnorm_first=True,
+        concat_resid=False,
+    )
+
+    logits = model(x, temporal_encoding=logits_hidden)
+
+    assert logits['dist'].shape == (batch_size, 1, height, width)
+    assert logits['edge'].shape == (batch_size, 1, height, width)
+    assert logits['mask'].shape == (batch_size, 2, height, width)
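Note: both new tests pin the same dist/edge/mask output convention: a 1-channel distance map, a 1-channel edge map, and a 2-channel mask. A hypothetical post-processing step built on those shapes; the 0.5 thresholds and the interior rule are illustrative assumptions, not cultionet's actual masking logic:

```python
import torch

# Fake logits with the shapes the tests above assert.
batch_size, height, width = 2, 100, 100
logits = {
    "dist": torch.rand(batch_size, 1, height, width),
    "edge": torch.rand(batch_size, 1, height, width),
    "mask": torch.rand(batch_size, 2, height, width),
}

crop_prob = logits["mask"].softmax(dim=1)[:, 1]  # P(crop) per pixel
edge_prob = logits["edge"].sigmoid().squeeze(1)  # P(field boundary) per pixel
# Assumed rule: a pixel is crop interior if it is likely crop and unlikely edge.
crop_interior = (crop_prob > 0.5) & (edge_prob < 0.5)

assert crop_interior.shape == (batch_size, height, width)
```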
diff --git a/tests/test_train.py b/tests/test_train.py
index ee437114..558716fa 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -1,7 +1,10 @@
+import json
+import subprocess
 import tempfile
+from pathlib import Path
 
-import joblib
-import pytorch_lightning as pl
+import lightning as L
+import numpy as np
 import torch
 
 import cultionet
@@ -11,10 +14,11 @@
 from cultionet.model import CultionetParams
 from cultionet.utils.project_paths import setup_paths
 
-pl.seed_everything(100)
+L.seed_everything(100)
+RNG = np.random.default_rng(200)
 
 
-def create_data() -> Data:
+def create_data(group: int) -> Data:
     num_channels = 2
     num_time = 12
     height = 10
@@ -27,7 +31,10 @@
     bdist = torch.rand((1, height, width), dtype=torch.float32)
     y = torch.randint(low=0, high=3, size=(1, height, width))
 
-    lat_left, lat_bottom, lat_right, lat_top = 1, 2, 3, 4
+    lat_left = RNG.uniform(low=-180, high=180)
+    lat_bottom = RNG.uniform(low=-90, high=90)
+    lat_right = RNG.uniform(low=-180, high=180)
+    lat_top = RNG.uniform(low=-90, high=90)
 
     batch_data = Data(
         x=x,
@@ -37,53 +44,79 @@
         bottom=torch.tensor([lat_bottom], dtype=torch.float32),
         right=torch.tensor([lat_right], dtype=torch.float32),
         top=torch.tensor([lat_top], dtype=torch.float32),
+        batch_id=[group],
     )
 
     return batch_data
 
 
-def test_train():
+# def test_train():
+#     num_data = 10
+#     with tempfile.TemporaryDirectory() as tmp_path:
+#         ppaths = setup_paths(tmp_path)
+#         for i in range(num_data):
+#             data_path = (
+#                 ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt'
+#             )
+#             batch_data = create_data(i)
+#             batch_data.to_file(data_path)
+
+#         dataset = EdgeDataset(
+#             ppaths.train_path,
+#             processes=0,
+#             threads_per_worker=1,
+#             random_seed=100,
+#         )
+
+#         cultionet_params = CultionetParams(
+#             ckpt_file=ppaths.ckpt_file,
+#             model_name="cultionet",
+#             dataset=dataset,
+#             val_frac=0.2,
+#             batch_size=2,
+#             load_batch_workers=0,
+#             hidden_channels=16,
+#             num_classes=2,
+#             edge_class=2,
+#             model_type=ModelTypes.TOWERUNET,
+#             res_block_type=ResBlockTypes.RESA,
+#             attention_weights=AttentionTypes.SPATIAL_CHANNEL,
+#             activation_type="SiLU",
+#             dilations=[1, 2],
+#             dropout=0.2,
+#             deep_supervision=True,
+#             pool_attention=False,
+#             pool_by_max=True,
+#             repeat_resa_kernel=False,
+#             batchnorm_first=True,
+#             epochs=1,
+#             device="cpu",
+#             devices=1,
+#             precision="16-mixed",
+#         )
+#         cultionet.fit(cultionet_params)
+
+def test_train_cli():
     num_data = 10
-    with tempfile.TemporaryDirectory() as tmp_path:
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        tmp_path = Path(tmp_dir)
         ppaths = setup_paths(tmp_path)
         for i in range(num_data):
             data_path = (
                 ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt'
             )
-            batch_data = create_data()
+            batch_data = create_data(i)
             batch_data.to_file(data_path)
 
-        dataset = EdgeDataset(
-            ppaths.train_path,
-            processes=0,
-            threads_per_worker=1,
-            random_seed=100,
-        )
-
-        cultionet_params = CultionetParams(
-            ckpt_file=ppaths.ckpt_file,
-            model_name="cultionet",
-            dataset=dataset,
-            val_frac=0.2,
-            batch_size=2,
-            load_batch_workers=0,
-            hidden_channels=16,
-            num_classes=2,
-            edge_class=2,
-            model_type=ModelTypes.TOWERUNET,
-            res_block_type=ResBlockTypes.RESA,
-            attention_weights=AttentionTypes.SPATIAL_CHANNEL,
-            activation_type="SiLU",
-            dilations=[1, 2],
-            dropout=0.2,
-            deep_supervision=True,
-            pool_attention=False,
-            pool_by_max=True,
-            repeat_resa_kernel=False,
-            batchnorm_first=True,
-            epochs=1,
-            device="cpu",
-            devices=1,
-            precision="16-mixed",
-        )
-        cultionet.fit(cultionet_params)
+        with open(tmp_path / "data/classes.info", "w") as f:
+            json.dump({"max_crop_class": 1, "edge_class": 2}, f)
+
+        command = f"cultionet train -p {str(tmp_path.absolute())} --val-frac 0.2 --augment-prob 0.5 --epochs 2 --hidden-channels 16 --processes 1 --load-batch-workers 0 --batch-size 2 --dropout 0.2 --deep-sup --dilations 1 2 --pool-by-max --learning-rate 0.01 --weight-decay 1e-4 --attention-weights natten"
+
+        try:
+            subprocess.run(
+                command, shell=True, check=True, capture_output=True
+            )
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError(e.stderr.decode()) from e
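Note: `test_train_cli` exercises the full CLI by shelling out with `shell=True` and re-raising on a non-zero exit (the original `NameError` is replaced with `RuntimeError` above, since no name lookup fails here). An equivalent sketch using an argument list, which avoids shell parsing and keeps stderr as text; the project path is hypothetical and the flags are a subset of the command in the test:

```python
import subprocess

# Sketch: the same invocation as an argument list (no shell=True).
command = [
    "cultionet", "train",
    "-p", "/tmp/project",  # hypothetical project path
    "--val-frac", "0.2",
    "--epochs", "2",
    "--batch-size", "2",
]

result = subprocess.run(command, capture_output=True, text=True)
if result.returncode != 0:
    # text=True decodes stderr, so no .decode() is needed here.
    raise RuntimeError(result.stderr)
```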