Skip to content

Commit

Permalink
Fix practically all the problems
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed May 15, 2023
1 parent c2d2403 commit 3282a70
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions torchgeo/models/rcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

"""Implementation of a random convolutional feature projection model."""

from typing import Any, Optional
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module
from torch.utils.data import Dataset

from ..datasets import NonGeoDataset


class RCF(Module):
Expand Down Expand Up @@ -109,12 +110,27 @@ def forward(self, x: Tensor) -> Tensor:


class MOSAIKS(RCF):
"""This model extracts MOSAIKS features from its input.
MOSAIKS features are described in Multi-task Observation using Satellite Imagery &
Kitchen Sinks https://www.nature.com/articles/s41467-021-24638-z. Briefly, this
model is instantiated with a dataset, samples patches from the dataset, ZCA whitens
the patches, then uses those as convolutional filters to extract features with.
.. note::
This Module is *not* trainable. It is only used as a feature extractor.
"""

def _normalize(
self, patches: np.ndarray, min_divisor: float = 1e-8, zca_bias: float = 0.001
) -> np.ndarray:
self,
patches: np.typing.NDArray[np.float32],
min_divisor: float = 1e-8,
zca_bias: float = 0.001,
) -> np.typing.NDArray[np.float32]:
"""Does ZCA whitening on a set of input patches.
Copied fromhttps://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120
Copied from https://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120
Args:
patches: a numpy array of size (N, C, H, W)
Expand All @@ -123,11 +139,7 @@ def _normalize(
Returns
a numpy array of size (N, C, H, W) containing the normalized patches
"""
if patches.dtype == "uint8":
patches = patches.astype("float32")
patches /= 255.0

""" # noqa: E501
n_patches = patches.shape[0]
orig_shape = patches.shape
patches = patches.reshape(patches.shape[0], -1)
Expand All @@ -152,13 +164,15 @@ def _normalize(
sqrt_zca_eigs = np.sqrt(E)
inv_sqrt_zca_eigs = np.diag(np.power(sqrt_zca_eigs, -1))
global_ZCA = V.dot(inv_sqrt_zca_eigs).dot(V.T)
patches_normalized: np.ndarray = (patches).dot(global_ZCA).dot(global_ZCA.T)
patches_normalized: np.typing.NDArray[np.float32] = (
(patches).dot(global_ZCA).dot(global_ZCA.T)
)

return patches_normalized.reshape(orig_shape).astype("float32")

def __init__(
self,
dataset: Dataset[dict[str, Any]],
dataset: NonGeoDataset,
in_channels: int = 4,
features: int = 16,
kernel_size: int = 3,
Expand Down

0 comments on commit 3282a70

Please sign in to comment.