From eaaea17e43400ce497ddd38ebe3d2a0d1c4d3dd1 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 27 Mar 2022 16:18:46 -0700 Subject: [PATCH] complete LogAvgExp2D --- README.md | 14 ++++++- logavgexp_pytorch/__init__.py | 2 +- logavgexp_pytorch/logavgexp_pytorch.py | 52 ++++++++++++++++++++++++++ setup.py | 3 +- 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index efc8574..e3bfe5c 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ from logavgexp_pytorch import logavgexp learned_temp = nn.Parameter(torch.ones(1) * -5).exp().clamp(min = 1e-8) # make sure temperature can't hit 0 x = torch.randn(1, 2048, 5) -y = logavgexp(x, temp = learned_temp, dim = -1) # (1, 5) +y = logavgexp(x, temp = learned_temp, dim = 1) # (1, 5) ``` Or you can use the `LogAvgExp` class to handle the learned temperature parameter @@ -69,6 +69,18 @@ x = torch.randn(1, 2048, 5) y = logavgexp(x) # (1, 5) ``` +## LogAvgExp2D + +```python +import torch +from logavgexp_pytorch import LogAvgExp2D + +logavgexp_pool = LogAvgExp2D((2, 2), stride = 2) # (2 x 2) pooling + +img = torch.randn(1, 16, 64, 64) +out = logavgexp_pool(img) # (1, 16, 32, 32) +``` + ## Citations ```bibtex diff --git a/logavgexp_pytorch/__init__.py b/logavgexp_pytorch/__init__.py index 8cbfdef..403668d 100644 --- a/logavgexp_pytorch/__init__.py +++ b/logavgexp_pytorch/__init__.py @@ -1 +1 @@ -from logavgexp_pytorch.logavgexp_pytorch import logavgexp, LogAvgExp +from logavgexp_pytorch.logavgexp_pytorch import logavgexp, LogAvgExp, LogAvgExp2D diff --git a/logavgexp_pytorch/logavgexp_pytorch.py b/logavgexp_pytorch/logavgexp_pytorch.py index c5dcd40..69054b3 100644 --- a/logavgexp_pytorch/logavgexp_pytorch.py +++ b/logavgexp_pytorch/logavgexp_pytorch.py @@ -3,12 +3,24 @@ from torch import nn import torch.nn.functional as F +from einops import rearrange + +# helper functions + def exists(t): return t is not None def log(t, eps = 1e-20): return torch.log(t + eps) +def cast_tuple(t, length = 1): + return t if isinstance(t, tuple) else ((t,) * length) + +def calc_conv_output(shape, kernel_size, padding, stride): + return tuple(map(lambda x: int((x[0] - x[1] + 2 * x[2]) / x[3] + 1), zip(shape, kernel_size, padding, stride))) + +# main function + def logavgexp( t, mask = None, @@ -36,6 +48,8 @@ def logavgexp( out = out.unsqueeze(dim) if keepdim else out return out +# learned temperature - logavgexp class + class LogAvgExp(nn.Module): def __init__( self, @@ -71,3 +85,41 @@ def forward(self, x, mask = None, eps = 1e-8): temp = temp, keepdim = self.keepdim ) + +# logavgexp 2d + +class LogAvgExp2D(nn.Module): + def __init__( + self, + kernel_size, + *, + padding = 0, + stride = 1, + temp = 0.01, + learned_temp = True, + eps = 1e-20, + **kwargs + ): + super().__init__() + self.padding = cast_tuple(padding, 2) + self.stride = cast_tuple(stride, 2) + self.kernel_size = cast_tuple(kernel_size, 2) + + self.unfold = nn.Unfold(self.kernel_size, padding = self.padding, stride = self.stride) + self.logavgexp = LogAvgExp(dim = -1, eps = eps, learned_temp = learned_temp, temp = temp) + + def forward(self, x): + """ + b - batch + c - channels + h - height + w - width + j - reducing dimension + """ + + b, c, h, w = x.shape + out_h, out_w = calc_conv_output((h, w), self.kernel_size, self.padding, self.stride) + + x = self.unfold(x) + x = rearrange(x, 'b (c j) (h w) -> b c h w j', h = out_h, w = out_w, c = c) + return self.logavgexp(x) diff --git a/setup.py b/setup.py index 4925795..79b68b0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'logavgexp-pytorch', packages = find_packages(exclude=[]), - version = '0.0.2', + version = '0.0.3', license='MIT', description = 'LogAvgExp - Pytorch', author = 'Phil Wang', @@ -16,6 +16,7 @@ 'logsumexp' ], install_requires=[ + 'einops>=0.4.1', 'torch>=1.6' ], classifiers=[