Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add empirical sampling mode to the RCF model #1339

Merged
merged 12 commits into from
Sep 29, 2023
30 changes: 23 additions & 7 deletions tests/models/test_rcf.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
import torch

from torchgeo.datasets import EuroSAT
from torchgeo.models import RCF


class TestRCF:
def test_in_channels(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
x = torch.randn(2, 5, 64, 64)
model(x)

model = RCF(in_channels=3, features=4, kernel_size=3)
model = RCF(in_channels=3, features=4, kernel_size=3, mode="gaussian")
match = "to have 3 channels, but got 5 channels instead"
with pytest.raises(RuntimeError, match=match):
model(x)

def test_num_features(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
x = torch.randn(2, 5, 64, 64)
y = model(x)
assert y.shape[1] == 4
Expand All @@ -29,14 +32,27 @@ def test_num_features(self) -> None:
assert y.shape[0] == 4

def test_untrainable(self) -> None:
model = RCF(in_channels=5, features=4, kernel_size=3)
model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
assert len(list(model.parameters())) == 0

def test_biases(self) -> None:
model = RCF(features=24, bias=10)
model = RCF(features=24, bias=10, mode="gaussian")
assert torch.all(model.biases == 10)

def test_seed(self) -> None:
weights1 = RCF(seed=1).weights
weights2 = RCF(seed=1).weights
weights1 = RCF(seed=1, mode="gaussian").weights
weights2 = RCF(seed=1, mode="gaussian").weights
assert torch.allclose(weights1, weights2)

def test_empirical(self) -> None:
root = os.path.join("tests", "data", "eurosat")
ds = EuroSAT(root=root, bands=EuroSAT.rgb_bands, split="train")
model = RCF(
in_channels=3, features=4, kernel_size=3, mode="empirical", dataset=ds
)
model(torch.randn(2, 3, 8, 8))

def test_empirical_no_dataset(self) -> None:
match = "dataset must be provided when mode is 'empirical'"
with pytest.raises(ValueError, match=match):
RCF(mode="empirical", dataset=None)
117 changes: 107 additions & 10 deletions torchgeo/models/rcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,33 @@

from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module

from ..datasets import NonGeoDataset


class RCF(Module):
"""This model extracts random convolutional features (RCFs) from its input.

RCFs are used in Multi-task Observation using Satellite Imagery & Kitchen Sinks
(MOSAIKS) method proposed in https://www.nature.com/articles/s41467-021-24638-z.
RCFs are used in the Multi-task Observation using Satellite Imagery & Kitchen Sinks
(MOSAIKS) method proposed in "A generalizable and accessible approach to machine
learning with global satellite imagery".

This class can operate in two modes, "gaussian" and "empirical". In "gaussian" mode,
the filters will be sampled from a Gaussian distribution, while in "empirical" mode,
the filters will be sampled from a dataset.

If you use this model in your research, please cite the following paper:

* https://www.nature.com/articles/s41467-021-24638-z

.. note::

This Module is *not* trainable. It is only used as a feature extractor.
This Module is *not* trainable. It is only used as a feature extractor.
"""

weights: Tensor
Expand All @@ -32,6 +44,8 @@ def __init__(
kernel_size: int = 3,
bias: float = -1.0,
seed: Optional[int] = None,
mode: str = "gaussian",
dataset: Optional[NonGeoDataset] = None,
) -> None:
"""Initializes the RCF model.

Expand All @@ -41,21 +55,28 @@ def __init__(
.. versionadded:: 0.2
The *seed* parameter.

.. versionadded:: 0.5
The *mode* and *dataset* parameters.

Args:
in_channels: number of input channels
features: number of features to compute, must be divisible by 2
kernel_size: size of the kernel used to compute the RCFs
bias: bias of the convolutional layer
seed: random seed used to initialize the convolutional layer
mode: "empirical" or "gaussian"
dataset: a NonGeoDataset to sample from when mode is "empirical"
"""
super().__init__()

assert mode in ["empirical", "gaussian"]
if mode == "empirical" and dataset is None:
raise ValueError("dataset must be provided when mode is 'empirical'")
assert features % 2 == 0
num_patches = features // 2

if seed is None:
generator = None
else:
generator = torch.Generator().manual_seed(seed)
generator = torch.Generator()
if seed:
generator = generator.manual_seed(seed)

# We register the weight and bias tensors as "buffers". This does two things:
# makes them behave correctly when we call .to(...) on the module, and makes
Expand All @@ -64,7 +85,7 @@ def __init__(
self.register_buffer(
"weights",
torch.randn(
features // 2,
num_patches,
in_channels,
kernel_size,
kernel_size,
Expand All @@ -73,9 +94,85 @@ def __init__(
),
)
self.register_buffer(
"biases", torch.zeros(features // 2, requires_grad=False) + bias
"biases", torch.zeros(num_patches, requires_grad=False) + bias
)

if mode == "empirical":
assert dataset is not None
num_channels, height, width = dataset[0]["image"].shape
assert num_channels == in_channels
patches = np.zeros(
(num_patches, num_channels, kernel_size, kernel_size), dtype=np.float32
)
idxs = torch.randint(
0, len(dataset), (num_patches,), generator=generator
).numpy()
ys = torch.randint(
0, height - kernel_size, (num_patches,), generator=generator
).numpy()
xs = torch.randint(
0, width - kernel_size, (num_patches,), generator=generator
).numpy()

for i in range(num_patches):
img = dataset[idxs[i]]["image"]
patches[i] = img[
:, ys[i] : ys[i] + kernel_size, xs[i] : xs[i] + kernel_size
]

patches = self._normalize(patches)
self.weights = torch.tensor(patches)

def _normalize(
self,
patches: "np.typing.NDArray[np.float32]",
min_divisor: float = 1e-8,
zca_bias: float = 0.001,
) -> "np.typing.NDArray[np.float32]":
"""Does ZCA whitening on a set of input patches.

Copied from https://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120

Args:
patches: a numpy array of size (N, C, H, W)
min_divisor: a small number to guard against division by zero
zca_bias: bias term for ZCA whitening

Returns
a numpy array of size (N, C, H, W) containing the normalized patches

.. versionadded:: 0.5
""" # noqa: E501
n_patches = patches.shape[0]
orig_shape = patches.shape
patches = patches.reshape(patches.shape[0], -1)

# Zero mean every feature
patches = patches - np.mean(patches, axis=1, keepdims=True)

# Normalize
patch_norms = np.linalg.norm(patches, axis=1)

# Get rid of really small norms
patch_norms[np.where(patch_norms < min_divisor)] = 1

# Make features unit norm
patches = patches / patch_norms[:, np.newaxis]

patchesCovMat = 1.0 / n_patches * patches.T.dot(patches)

(E, V) = np.linalg.eig(patchesCovMat)

E += zca_bias
sqrt_zca_eigs = np.sqrt(E)
inv_sqrt_zca_eigs = np.diag(np.power(sqrt_zca_eigs, -1))
global_ZCA = V.dot(inv_sqrt_zca_eigs).dot(V.T)
patches_normalized: "np.typing.NDArray[np.float32]" = (
(patches).dot(global_ZCA).dot(global_ZCA.T)
)

return patches_normalized.reshape(orig_shape).astype("float32")

def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the RCF model.

Expand Down