22 changes: 19 additions & 3 deletions monai/networks/blocks/patchembedding.py
@@ -12,21 +12,22 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding
from monai.networks.layers import Conv, trunc_normal_
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}


class PatchEmbeddingBlock(nn.Module):
@@ -53,6 +54,7 @@ def __init__(
        pos_embed_type: str = "learnable",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        pos_embed_kwargs: Optional[dict] = None,
    ) -> None:
"""
Args:
@@ -65,6 +67,8 @@
            pos_embed_type: position embedding layer type.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.
            pos_embed_kwargs: additional arguments for the position embedding. For `sincos`, it may contain
                `temperature`; for `fourier`, it may contain `scales`.
        """

        super().__init__()
@@ -105,6 +109,8 @@ def __init__(
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
        self.dropout = nn.Dropout(dropout_rate)

        pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs

        if self.pos_embed_type == "none":
            pass
        elif self.pos_embed_type == "learnable":
@@ -114,7 +120,17 @@
            for in_size, pa_size in zip(img_size, patch_size):
                grid_size.append(in_size // pa_size)

            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
            self.position_embeddings = build_sincos_position_embedding(
                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
            )
        elif self.pos_embed_type == "fourier":
            grid_size = []
            for in_size, pa_size in zip(img_size, patch_size):
                grid_size.append(in_size // pa_size)

            self.position_embeddings = build_fourier_position_embedding(
                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
            )
        else:
            raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")

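With these changes, the new embedding is selected via `pos_embed_type="fourier"` and tuned through `pos_embed_kwargs`. A minimal usage sketch (parameter values mirror the new test below; the anisotropic `scales` list is an illustrative choice, not part of this PR):

```python
import torch

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock

block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,  # must be even for the fourier embedding
    num_heads=8,
    pos_embed_type="fourier",
    dropout_rate=0.0,
    spatial_dims=3,
    pos_embed_kwargs={"scales": [1.0, 2.0, 2.0]},  # one scale per spatial dim (illustrative values)
)

x = torch.randn(2, 1, 32, 32, 32)               # (batch, channels, *img_size)
out = block(x)
print(out.shape)                                # torch.Size([2, 64, 96]): 4*4*4 patches, hidden_size 96
print(block.position_embeddings.requires_grad)  # False: the embedding is fixed, not learned
```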
55 changes: 54 additions & 1 deletion monai/networks/blocks/pos_embed_utils.py
@@ -18,7 +18,7 @@
import torch
import torch.nn as nn

__all__ = ["build_sincos_position_embedding"]
__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]


# From PyTorch internals
@@ -32,6 +32,59 @@ def parse(x):
    return parse


def build_fourier_position_embedding(
    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
) -> torch.nn.Parameter:
"""
Builds a (Anistropic) Fourier feature position embedding based on the given grid size, embed dimension,
spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant
points more distinguishable.
Position embedding is made anistropic by allowing setting different scales for each spatial dimension.
Reference: https://arxiv.org/abs/2509.02488

Args:
grid_size (int | List[int]): The size of the grid in each spatial dimension.
embed_dim (int): The dimension of the embedding.
spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
scales (float | List[float]): The scale for every spatial dimension. If a single float is provided,
the same scale is used for all dimensions.

Returns:
pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
"""

    to_tuple = _ntuple(spatial_dims)
    grid_size_t = to_tuple(grid_size)
    if len(grid_size_t) != spatial_dims:
        raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.")

    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even for Fourier position embedding")

    # Ensure scales is a tensor of shape (spatial_dims,)
    if isinstance(scales, float):
        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
    elif isinstance(scales, (list, tuple)):
        if len(scales) != spatial_dims:
            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
        scales_tensor = torch.tensor(scales, dtype=torch.float)
    else:
        raise TypeError(f"scales must be float or list of floats, got {type(scales)}")

    # Random frequency matrix W; rows are drawn from N(0, scales^2), one scale per spatial dimension
    gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor

    # Normalized grid coordinates in [0, 1], flattened to (n_patches, spatial_dims)
    position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]
    positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1)
    positions = positions.flatten(end_dim=-2)

    x_proj = (2.0 * torch.pi * positions) @ gaussians.T

    # Fourier features: [sin(2*pi*p @ W.T), cos(2*pi*p @ W.T)]
    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
    pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False)

    return pos_emb


def build_sincos_position_embedding(
    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
) -> torch.nn.Parameter:
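To see what the new helper produces on its own, here is a small sketch calling `build_fourier_position_embedding` directly (the grid and scale values are illustrative). Positions `p` on a normalized `[0, 1]` grid are projected by the random Gaussian matrix `W`, and the embedding concatenates the sine and cosine of the projection:

```python
import torch

from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

torch.manual_seed(0)  # the projection matrix is random, so seed for reproducibility
pos_embed = build_fourier_position_embedding(
    grid_size=(4, 4, 4), embed_dim=96, spatial_dims=3, scales=[1.0, 2.0, 2.0]
)
print(pos_embed.shape)          # torch.Size([1, 64, 96]): (1, prod(grid_size), embed_dim)
print(pos_embed.requires_grad)  # False: returned as a frozen nn.Parameter

# Each (sin, cos) feature pair lies on the unit circle: sin^2 + cos^2 == 1
half = pos_embed.shape[-1] // 2
assert torch.allclose(
    pos_embed[..., :half] ** 2 + pos_embed[..., half:] ** 2, torch.ones(1, 64, half)
)
```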
39 changes: 39 additions & 0 deletions tests/networks/blocks/test_patchembedding.py
@@ -87,6 +87,19 @@ def test_sincos_pos_embed(self):

        self.assertEqual(net.position_embeddings.requires_grad, False)

    def test_fourier_pos_embed(self):
        net = PatchEmbeddingBlock(
            in_channels=1,
            img_size=(32, 32, 32),
            patch_size=(8, 8, 8),
            hidden_size=96,
            num_heads=8,
            pos_embed_type="fourier",
            dropout_rate=0.5,
        )

        self.assertEqual(net.position_embeddings.requires_grad, False)

    def test_learnable_pos_embed(self):
        net = PatchEmbeddingBlock(
            in_channels=1,
@@ -101,6 +114,32 @@ def test_learnable_pos_embed(self):
        self.assertEqual(net.position_embeddings.requires_grad, True)

    def test_ill_arg(self):
        with self.assertRaises(ValueError):
            PatchEmbeddingBlock(
                in_channels=1,
                img_size=(128, 128, 128),
                patch_size=(16, 16, 16),
                hidden_size=128,
                num_heads=12,
                proj_type="conv",
                dropout_rate=0.1,
                pos_embed_type="fourier",
                pos_embed_kwargs=dict(scales=[1.0, 1.0]),
            )

        with self.assertRaises(ValueError):
            PatchEmbeddingBlock(
                in_channels=1,
                img_size=(128, 128),
                patch_size=(16, 16),
                hidden_size=128,
                num_heads=12,
                proj_type="conv",
                dropout_rate=0.1,
                pos_embed_type="fourier",
                pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),
            )

        with self.assertRaises(ValueError):
            PatchEmbeddingBlock(
                in_channels=1,
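The two new `test_ill_arg` cases exercise the `scales`/`spatial_dims` length check added in `build_fourier_position_embedding`. A quick sketch of the same validation, calling the helper directly (values taken from the first new test case):

```python
from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

# 3 spatial dims but only 2 scales: rejected before any embedding is built
try:
    build_fourier_position_embedding(grid_size=(8, 8, 8), embed_dim=128, spatial_dims=3, scales=[1.0, 1.0])
except ValueError as e:
    print(e)  # Length of scales 2 does not match spatial_dims 3

# An odd embed_dim is also rejected, since features come in sin/cos pairs
try:
    build_fourier_position_embedding(grid_size=(8, 8, 8), embed_dim=127, spatial_dims=3)
except ValueError as e:
    print(e)  # embed_dim must be even for Fourier position embedding
```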