Add PyTorch implementation for P4 and P4M GConv #5

Open: wants to merge 3 commits into base: master
31 changes: 31 additions & 0 deletions groupy/gconv/pytorch_gconv/p4_conv.py
@@ -0,0 +1,31 @@
from groupy.gconv.pytorch_gconv.splitgconv2d import SplitGConv2D
from groupy.gconv.make_gconv_indices import make_c4_z2_indices, \
make_c4_p4_indices, flatten_indices


class P4ConvZ2(SplitGConv2D):
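    """G-convolution that lifts a feature map on the plane Z2 to a
    function on the roto-translation group p4 (four rotation channels)."""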

@property
def input_stabilizer_size(self):
return 1

@property
def output_stabilizer_size(self):
return 4

def make_transformation_indices(self, ksize):
return flatten_indices(make_c4_z2_indices(ksize=ksize))


class P4ConvP4(SplitGConv2D):
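    """G-convolution from p4 to p4; both input and output carry a
    four-element rotation axis."""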

@property
def input_stabilizer_size(self):
return 4

@property
def output_stabilizer_size(self):
return 4

def make_transformation_indices(self, ksize):
return flatten_indices(make_c4_p4_indices(ksize=ksize))
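A minimal usage sketch for these layers (not part of the diff; the channel sizes and the 28x28 input are illustrative, and the Variable wrapper matches the PyTorch version this PR targets):

import torch
from torch.autograd import Variable
from groupy.gconv.pytorch_gconv.p4_conv import P4ConvZ2, P4ConvP4

# Lift a plain image to p4, then convolve on p4.
x = Variable(torch.randn(1, 3, 28, 28))  # (batch, channels, H, W)
l1 = P4ConvZ2(in_channels=3, out_channels=8, ksize=3, pad=1)
l2 = P4ConvP4(in_channels=8, out_channels=16, ksize=3, pad=1)
y = l2(l1(x))
print(y.size())  # (1, 16, 4, 28, 28): one feature map per rotation in C4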
31 changes: 31 additions & 0 deletions groupy/gconv/pytorch_gconv/p4m_conv.py
@@ -0,0 +1,31 @@
from groupy.gconv.pytorch_gconv.splitgconv2d import SplitGConv2D
from groupy.gconv.make_gconv_indices import make_d4_z2_indices, \
make_d4_p4m_indices, flatten_indices


class P4MConvZ2(SplitGConv2D):
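    """G-convolution that lifts a feature map on the plane Z2 to a
    function on the roto-reflection group p4m (eight transformation
    channels: four rotations times two mirror states)."""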

@property
def input_stabilizer_size(self):
return 1

@property
def output_stabilizer_size(self):
return 8

def make_transformation_indices(self, ksize):
return flatten_indices(make_d4_z2_indices(ksize=ksize))


class P4MConvP4M(SplitGConv2D):
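    """G-convolution from p4m to p4m; both input and output carry an
    eight-element transformation axis."""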

@property
def input_stabilizer_size(self):
return 8

@property
def output_stabilizer_size(self):
return 8

def make_transformation_indices(self, ksize):
return flatten_indices(make_d4_p4m_indices(ksize=ksize))
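The p4m layers follow the same pattern, with a stabilizer of size 8; again a sketch with illustrative sizes:

import torch
from torch.autograd import Variable
from groupy.gconv.pytorch_gconv.p4m_conv import P4MConvZ2, P4MConvP4M

x = Variable(torch.randn(1, 3, 28, 28))
y = P4MConvP4M(8, 16, ksize=3, pad=1)(P4MConvZ2(3, 8, ksize=3, pad=1)(x))
print(y.size())  # (1, 16, 8, 28, 28): one map per roto-reflection in D4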
155 changes: 155 additions & 0 deletions groupy/gconv/pytorch_gconv/splitgconv2d.py
@@ -0,0 +1,155 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as nninit
from torch.autograd import Variable


def _pair(x):
if hasattr(x, '__getitem__'):
return x
else:
return (x, x)


class SplitGConv2D(nn.Module):
"""
Group convolution base class for split plane groups.

A plane group (aka wallpaper group) is a group of distance-preserving
transformations that includes two independent discrete translations.

A group is called split (or symmorphic) if every element in this group can
be written as the composition of an element from the "stabilizer of the
origin" and a translation. The stabilizer of the origin consists of those
transformations in the group that leave the origin fixed. For example, the
stabilizer in the rotation-translation group p4 is the set of rotations
around the origin, which is (isomorphic to) the group C4.

Most plane groups are split, but some include glide-reflection generators;
such groups are not split. For split groups G, the G-conv can be split
into a "filter transform" and "translational convolution" part.

Different subclasses of this class implement the filter transform for
various groups, while this class implements the common functionality.

    This PyTorch implementation mimics the original Chainer implementation.
"""

def __init__(self,
in_channels,
out_channels,
ksize=3,
flat_channels=False,
stride=1,
pad=0,
bias=True,
*args, **kwargs):
"""
        :param in_channels: number of input channels (per group transformation).
        :param out_channels: number of output channels (per group transformation).
        :param ksize: kernel size; only square filters are supported.
        :param flat_channels: if True, inputs and outputs use a flat
            (batch, channels * stabilizer_size, ny, nx) layout instead of a
            separate group dimension (batch, channels, stabilizer_size, ny, nx).
        :param stride: stride of the convolution (int or pair).
        :param pad: zero-padding added to both sides (int or pair).
        :param bias: if True, adds a learnable bias shared across each
            G-feature map.
"""

super(SplitGConv2D, self).__init__(*args, **kwargs)

if not isinstance(ksize, int):
raise TypeError('ksize must be an integer (only square filters '
'are supported).')

self.in_channels = in_channels
self.out_channels = out_channels
self.ksize = ksize
self.stride = _pair(stride)
self.pad = _pair(pad)
self.flat_channels = flat_channels
self.use_bias = bias

self.weight = nn.Parameter(torch.Tensor(self.out_channels,
self.in_channels,
self.input_stabilizer_size,
self.ksize,
self.ksize))
nninit.xavier_normal(self.weight)

        if self.use_bias:
            self.bias = nn.Parameter(
                torch.zeros(self.out_channels))
        else:
            self.register_parameter('bias', None)

# Shorthands
ni, no = in_channels, out_channels
nti, nto = self.input_stabilizer_size, self.output_stabilizer_size
n = self.ksize

self.expand_shape = (no, nto, ni, nti * n * n)
self.weight_shape = (no * nto, ni * nti, n, n)
self.weight_flat_shape = (no, 1, ni, nti * n * n)

transform_indices = self._create_indices(self.expand_shape)
self.register_buffer('transform_indices', transform_indices)

def _create_indices(self, expand_shape):
no, nto, ni, r = expand_shape
transform_indices = self.make_transformation_indices(ksize=self.ksize)
transform_indices = transform_indices.astype(np.int64)
transform_indices = transform_indices.reshape(1, nto, 1, r)
transform_indices = torch.from_numpy(transform_indices)
transform_indices = transform_indices.expand(*expand_shape)
return transform_indices

    @property
    def input_stabilizer_size(self):
        raise NotImplementedError()

    @property
    def output_stabilizer_size(self):
        raise NotImplementedError()

def make_transformation_indices(self, ksize):
raise NotImplementedError()

def forward(self, x):
# Transform the filters
w_flat_ = self.weight.view(self.weight_flat_shape)
w_flat = w_flat_.expand(*self.expand_shape)
w = torch.gather(w_flat, 3, Variable(self.transform_indices)) \
.view(self.weight_shape)

# If flat_channels is False, we need to flatten the input feature maps
# to have a single 1d feature dimension.
if not self.flat_channels:
batch_size = x.size(0)
in_ny, in_nx = x.size()[-2:]
x = x.view(batch_size,
self.in_channels * self.input_stabilizer_size,
in_ny,
in_nx)

# Perform the 2D convolution
y = F.conv2d(x, w, stride=self.stride, padding=self.pad)

# Unfold the output feature maps
# We do this even if flat_channels is True, because we need to add the
# same bias to each G-feature map
batch_size, _, ny_out, nx_out = y.size()
y = y.view(batch_size, self.out_channels, self.output_stabilizer_size,
ny_out, nx_out)

# Add a bias to each G-feature map
if self.use_bias:
b = self.bias.view(1, self.out_channels, 1, 1, 1)
b = b.expand_as(y)
y = y + b

# Flatten feature channels if needed
if self.flat_channels:
            n, nc, ng, ny_out, nx_out = y.size()
            y = y.view(n, nc * ng, ny_out, nx_out)

return y
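To make the gather-based filter transform concrete, here is a small standalone sketch (sizes are hypothetical; it mirrors the reshapes in _create_indices and forward) that expands a P4ConvZ2 filter bank into its four rotated copies:

import numpy as np
import torch
from groupy.gconv.make_gconv_indices import make_c4_z2_indices, flatten_indices

no, ni, nti, nto, n = 2, 3, 1, 4, 3  # out/in channels, stabilizer sizes, ksize
w = torch.randn(no, ni, nti, n, n)   # raw parameters, as stored by SplitGConv2D

# Flat gather indices: for each output transformation, where to read
# each element of the flattened (nti * n * n) filter.
inds = flatten_indices(make_c4_z2_indices(ksize=n)).astype(np.int64)
inds = torch.from_numpy(inds.reshape(1, nto, 1, nti * n * n))
inds = inds.expand(no, nto, ni, nti * n * n)

w_flat = w.view(no, 1, ni, nti * n * n).expand(no, nto, ni, nti * n * n)
w_full = torch.gather(w_flat, 3, inds).view(no * nto, ni * nti, n, n)
print(w_full.size())  # (8, 3, 3, 3): each filter appears once per rotation in C4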
121 changes: 121 additions & 0 deletions groupy/gconv/pytorch_gconv/test_gconv.py
@@ -0,0 +1,121 @@
import numpy as np
import torch
from torch.autograd import Variable


def test_p4_net_equivariance():
from groupy.gfunc import Z2FuncArray, P4FuncArray
import groupy.garray.C4_array as c4a
from groupy.gconv.pytorch_gconv.p4_conv import P4ConvZ2, P4ConvP4

im = np.random.randn(1, 1, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[
P4ConvZ2(in_channels=1, out_channels=2, ksize=3),
P4ConvP4(in_channels=2, out_channels=3, ksize=3)
],
input_array=Z2FuncArray,
output_array=P4FuncArray,
point_group=c4a,
)


def test_p4m_net_equivariance():
from groupy.gfunc import Z2FuncArray, P4MFuncArray
import groupy.garray.D4_array as d4a
from groupy.gconv.pytorch_gconv.p4m_conv import P4MConvZ2, P4MConvP4M

im = np.random.randn(1, 1, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[
P4MConvZ2(in_channels=1, out_channels=2, ksize=3),
P4MConvP4M(in_channels=2, out_channels=3, ksize=3)
],
input_array=Z2FuncArray,
output_array=P4MFuncArray,
point_group=d4a,
)


def test_g_z2_conv_equivariance():
from groupy.gfunc import Z2FuncArray, P4FuncArray, P4MFuncArray
import groupy.garray.C4_array as c4a
import groupy.garray.D4_array as d4a
from groupy.gconv.pytorch_gconv.p4_conv import P4ConvZ2
from groupy.gconv.pytorch_gconv.p4m_conv import P4MConvZ2

im = np.random.randn(1, 1, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[P4ConvZ2(1, 2, 3)],
input_array=Z2FuncArray,
output_array=P4FuncArray,
point_group=c4a,
)

check_equivariance(
im=im,
layers=[P4MConvZ2(1, 2, 3)],
input_array=Z2FuncArray,
output_array=P4MFuncArray,
point_group=d4a,
)


def test_p4_p4_conv_equivariance():
from groupy.gfunc import P4FuncArray
import groupy.garray.C4_array as c4a
from groupy.gconv.pytorch_gconv.p4_conv import P4ConvP4

im = np.random.randn(1, 1, 4, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[P4ConvP4(1, 2, 3)],
input_array=P4FuncArray,
output_array=P4FuncArray,
point_group=c4a,
)


def test_p4m_p4m_conv_equivariance():
from groupy.gfunc import P4MFuncArray
import groupy.garray.D4_array as d4a
from groupy.gconv.pytorch_gconv.p4m_conv import P4MConvP4M

im = np.random.randn(1, 1, 8, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[P4MConvP4M(1, 2, 3)],
input_array=P4MFuncArray,
output_array=P4MFuncArray,
point_group=d4a,
)


def check_equivariance(im, layers, input_array, output_array, point_group):

# Transform the image
f = input_array(im)
g = point_group.rand()
gf = g * f
im1 = gf.v

# Apply layers to both images
im = Variable(torch.from_numpy(im))
im1 = Variable(torch.from_numpy(im1))

fmap = im
fmap1 = im1
for layer in layers:
print(layer)
fmap = layer(fmap)
fmap1 = layer(fmap1)

# Transform the computed feature maps
fmap1_garray = output_array(fmap1.data.numpy())
r_fmap1_data = (g.inv() * fmap1_garray).v

fmap_data = fmap.data.numpy()
assert np.allclose(fmap_data, r_fmap1_data, rtol=1e-5, atol=1e-3)
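When running this file directly (pytest or nose will also pick up the test_ functions), a hypothetical main block could drive all five checks:

if __name__ == '__main__':
    test_p4_net_equivariance()
    test_p4m_net_equivariance()
    test_g_z2_conv_equivariance()
    test_p4_p4_conv_equivariance()
    test_p4m_p4m_conv_equivariance()
    print('All equivariance checks passed.')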