Skip to content

Commit

Permalink
complete LogAvgExp2D
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 27, 2022
1 parent c567276 commit eaaea17
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 3 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ from logavgexp_pytorch import logavgexp
learned_temp = nn.Parameter(torch.ones(1) * -5).exp().clamp(min = 1e-8) # make sure temperature can't hit 0

x = torch.randn(1, 2048, 5)
y = logavgexp(x, temp = learned_temp, dim = -1) # (1, 5)
y = logavgexp(x, temp = learned_temp, dim = 1) # (1, 5)
```

Or you can use the `LogAvgExp` class to handle the learned temperature parameter
Expand All @@ -69,6 +69,18 @@ x = torch.randn(1, 2048, 5)
y = logavgexp(x) # (1, 5)
```

## LogAvgExp2D

```python
import torch
from logavgexp_pytorch import LogAvgExp2D

logavgexp_pool = LogAvgExp2D((2, 2), stride = 2) # (2 x 2) pooling

img = torch.randn(1, 16, 64, 64)
out = logavgexp_pool(img) # (1, 16, 32, 32)
```

## Citations

```bibtex
Expand Down
2 changes: 1 addition & 1 deletion logavgexp_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from logavgexp_pytorch.logavgexp_pytorch import logavgexp, LogAvgExp
from logavgexp_pytorch.logavgexp_pytorch import logavgexp, LogAvgExp, LogAvgExp2D
52 changes: 52 additions & 0 deletions logavgexp_pytorch/logavgexp_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,24 @@
from torch import nn
import torch.nn.functional as F

from einops import rearrange

# helper functions

def exists(t):
return t is not None

def log(t, eps = 1e-20):
return torch.log(t + eps)

def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else ((t,) * length)

def calc_conv_output(shape, kernel_size, padding, stride):
return tuple(map(lambda x: int((x[0] - x[1] + 2 * x[2]) / x[3] + 1), zip(shape, kernel_size, padding, stride)))

# main function

def logavgexp(
t,
mask = None,
Expand Down Expand Up @@ -36,6 +48,8 @@ def logavgexp(
out = out.unsqueeze(dim) if keepdim else out
return out

# learned temperature - logavgexp class

class LogAvgExp(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -71,3 +85,41 @@ def forward(self, x, mask = None, eps = 1e-8):
temp = temp,
keepdim = self.keepdim
)

# logavgexp 2d

class LogAvgExp2D(nn.Module):
def __init__(
self,
kernel_size,
*,
padding = 0,
stride = 1,
temp = 0.01,
learned_temp = True,
eps = 1e-20,
**kwargs
):
super().__init__()
self.padding = cast_tuple(padding, 2)
self.stride = cast_tuple(stride, 2)
self.kernel_size = cast_tuple(kernel_size, 2)

self.unfold = nn.Unfold(self.kernel_size, padding = self.padding, stride = self.stride)
self.logavgexp = LogAvgExp(dim = -1, eps = eps, learned_temp = learned_temp, temp = temp)

def forward(self, x):
"""
b - batch
c - channels
h - height
w - width
j - reducing dimension
"""

b, c, h, w = x.shape
out_h, out_w = calc_conv_output((h, w), self.kernel_size, self.padding, self.stride)

x = self.unfold(x)
x = rearrange(x, 'b (c j) (h w) -> b c h w j', h = out_h, w = out_w, c = c)
return self.logavgexp(x)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'logavgexp-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'LogAvgExp - Pytorch',
author = 'Phil Wang',
Expand All @@ -16,6 +16,7 @@
'logsumexp'
],
install_requires=[
'einops>=0.4.1',
'torch>=1.6'
],
classifiers=[
Expand Down

0 comments on commit eaaea17

Please sign in to comment.