🔀 Merge branch 'main' into extra-data-and-pre-batch-shuffle
Commented out the extra california_*.hdf5 data for now.
Showing 14 changed files with 3,075 additions and 2,011 deletions.
@@ -0,0 +1,63 @@
from lightning.pytorch.callbacks import Callback
import torch
import wandb


class LogIntermediatePredictions(Callback):
    """Visualize the model results at the end of every epoch."""

    def __init__(self, logger):
        """Instantiate with a wandb logger.

        Args:
            logger: wandb logger instance.
        """
        super().__init__()
        self.logger = logger

    def on_validation_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Called when the validation batch ends.

        At the end of each epoch, takes a sample from the validation dataset and
        logs the images with the model predictions to the wandb logger, so that
        humans can see how the model evolves over time.
        """
        if batch_idx == 0:
            # Log only the first batch to keep the sample size small
            id2label = {0: "ok", 1: "burn"}
            log_list = []

            with torch.no_grad():
                pre_img, post_img, mask, metadata = batch
                batch_size = mask.shape[0]

                # Run the model to get binary change predictions
                logits: torch.Tensor = pl_module(x1=pre_img, x2=post_img).squeeze()
                y_pred: torch.Tensor = torch.sigmoid(logits)
                y_pred = (y_pred > 0.5).int().detach().cpu().numpy()

                for i in range(batch_size):
                    log_image = wandb.Image(
                        # Rough brightness normalization for display
                        post_img[i].permute(1, 2, 0).detach().cpu().numpy() / 6000,
                        masks={
                            "prediction": {
                                "mask_data": y_pred[i],
                                "class_labels": id2label,
                            },
                            "ground_truth": {
                                "mask_data": mask[i].detach().cpu().numpy(),
                                "class_labels": id2label,
                            },
                        },
                    )
                    log_list.append(log_image)

                wandb.log({"predictions": log_list})
@@ -0,0 +1,119 @@
""" | ||
Modular block layers of the TinyCD model. | ||
Reference: | ||
- https://github.com/AndreaCodegoni/Tiny_model_4_CD/blob/main/models/layers.py | ||
- Codegoni, A., Lombardi, G., & Ferrari, A. (2022). TINYCD: A (Not So) Deep | ||
Learning Model For Change Detection (arXiv:2207.13159). arXiv. | ||
https://doi.org/10.48550/arXiv.2207.13159 | ||
""" | ||
from typing import List, Optional | ||
|
||
from torch import Tensor, reshape, stack | ||
from torch.nn import Conv2d, InstanceNorm2d, Module, PReLU, Sequential, Upsample | ||
|
||
|
||
class PixelwiseLinear(Module):
    def __init__(
        self,
        fin: List[int],
        fout: List[int],
        last_activation: Optional[Module] = None,
    ) -> None:
        assert len(fout) == len(fin)
        super().__init__()

        n = len(fin)
        self._linears = Sequential(
            *[
                Sequential(
                    Conv2d(fin[i], fout[i], kernel_size=1, bias=True),
                    PReLU()
                    if i < n - 1 or last_activation is None
                    else last_activation,
                )
                for i in range(n)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        # Process the tensor with the stack of 1x1 convolutions:
        return self._linears(x)
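Because every layer is a 1x1 convolution, PixelwiseLinear acts as a small MLP applied independently at each pixel. A quick shape sketch (illustrative sizes, not part of the commit):

import torch

mlp = PixelwiseLinear(fin=[32, 16], fout=[16, 8])
x = torch.randn(4, 32, 64, 64)  # (batch, channels, height, width)
print(mlp(x).shape)  # torch.Size([4, 8, 64, 64]) -- spatial dims are untouched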
class MixingBlock(Module):
    def __init__(
        self,
        ch_in: int,
        ch_out: int,
    ):
        super().__init__()
        self._convmix = Sequential(
            Conv2d(ch_in, ch_out, 3, groups=ch_out, padding=1),
            PReLU(),
            InstanceNorm2d(ch_out),
        )

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # Pack the two tensors, interleaving their channels:
        mixed = stack((x, y), dim=2)
        mixed = reshape(mixed, (x.shape[0], -1, x.shape[2], x.shape[3]))

        # Mix each channel pair with the grouped convolution:
        return self._convmix(mixed)
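To see the channel interleaving concretely, a tiny sketch (illustrative values, not part of the commit):

import torch

x = torch.tensor([0., 1., 2.]).reshape(1, 3, 1, 1)
y = torch.tensor([10., 11., 12.]).reshape(1, 3, 1, 1)
mixed = torch.stack((x, y), dim=2).reshape(1, -1, 1, 1)
print(mixed.flatten().tolist())  # [0.0, 10.0, 1.0, 11.0, 2.0, 12.0]
# With groups=ch_out, each group of the 3x3 convolution sees exactly one
# (x_i, y_i) channel pair, so features from the two images are mixed pairwise.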
class MixingMaskAttentionBlock(Module):
    """Use grouped convolutions to build a sort of attention mask."""

    def __init__(
        self,
        ch_in: int,
        ch_out: int,
        fin: List[int],
        fout: List[int],
        generate_masked: bool = False,
    ):
        super().__init__()
        self._mixing = MixingBlock(ch_in, ch_out)
        self._linear = PixelwiseLinear(fin, fout)
        self._final_normalization = InstanceNorm2d(ch_out) if generate_masked else None
        self._mixing_out = MixingBlock(ch_in, ch_out) if generate_masked else None

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        z_mix = self._mixing(x, y)
        z = self._linear(z_mix)
        z_mix_out = 0 if self._mixing_out is None else self._mixing_out(x, y)

        # Either return the pixelwise mask itself, or use it to gate a
        # second mixing of the inputs:
        return (
            z
            if self._final_normalization is None
            else self._final_normalization(z_mix_out * z)
        )
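The block therefore has two modes, selected by generate_masked. A shape sketch under assumed channel sizes (illustrative, not from the commit):

import torch

x = torch.randn(2, 3, 64, 64)
y = torch.randn(2, 3, 64, 64)

# Default mode: returns the pixelwise "attention" mask itself.
attn = MixingMaskAttentionBlock(ch_in=6, ch_out=3, fin=[3, 10], fout=[10, 1])
print(attn(x, y).shape)  # torch.Size([2, 1, 64, 64])

# Masked mode: the mask gates a second mixing of the inputs.
gated = MixingMaskAttentionBlock(
    ch_in=6, ch_out=3, fin=[3, 10], fout=[10, 3], generate_masked=True
)
print(gated(x, y).shape)  # torch.Size([2, 3, 64, 64])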
class UpMask(Module):
    def __init__(
        self,
        scale_factor: float,
        nin: int,
        nout: int,
    ):
        super().__init__()
        self._upsample = Upsample(
            scale_factor=scale_factor, mode="bilinear", align_corners=True
        )
        self._convolution = Sequential(
            Conv2d(nin, nin, 3, 1, groups=nin, padding=1),
            PReLU(),
            InstanceNorm2d(nin),
            Conv2d(nin, nout, kernel_size=1, stride=1),
            PReLU(),
            InstanceNorm2d(nout),
        )

    def forward(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor:
        # Upsample, optionally gate with a same-resolution mask, then refine:
        x = self._upsample(x)
        if y is not None:
            x = x * y
        return self._convolution(x)
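Finally, a sketch of UpMask with assumed shapes (illustrative, not from the commit): it upsamples a coarse feature map and, when given a mask already at the target resolution, gates the features with it before refining them.

import torch

up = UpMask(scale_factor=2, nin=3, nout=8)
coarse = torch.randn(2, 3, 32, 32)
mask = torch.randn(2, 3, 64, 64)  # same resolution as the upsampled features
print(up(coarse, mask).shape)     # torch.Size([2, 8, 64, 64])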