correctly attribute
lucidrains committed Aug 14, 2021
1 parent 9be8eed commit d9d0a18
Showing 3 changed files with 15 additions and 36 deletions.
README.md (31 changes: 5 additions & 26 deletions)
@@ -24,7 +24,8 @@ model = gMLP(
dim = 512,
depth = 6,
seq_len = 256,
circulant_matrix = True, # use circulant weight matrix for linear increase in parameters with respect to sequence length
act = nn.Tanh() # activation for spatial gate (defaults to identity)
)

x = torch.randint(0, 20000, (1, 256))
@@ -89,29 +90,7 @@ pred = model(img) # (1, 1000)

## Experimental

An independent researcher proposes using circulant matrices in gMLPs in <a href="https://zhuanlan.zhihu.com/p/395005917">a blogpost on Zhihu</a>. This allows you to scale gMLPs to longer sequence lengths with a parameter cost that grows linearly (as opposed to quadratically). My experiments show an improved rate of convergence.
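
For intuition, here is a minimal sketch (not part of this repository) of why a circulant spatial weight needs only `seq_len` parameters: the full `seq_len × seq_len` matrix is generated from a single learnable vector.

```python
import torch

# illustrative only: a circulant matrix is fully determined by one vector,
# so the parameter count grows linearly with sequence length
seq_len = 256
w = torch.randn(seq_len)                                # single learnable vector
idx = torch.arange(seq_len)
circulant = w[(idx[:, None] - idx[None, :]) % seq_len]  # (seq_len, seq_len), entry (i, j) = w[(i - j) % seq_len]

print(circulant.shape)  # torch.Size([256, 256]), built from only 256 parameters
```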

You can use it by setting one extra flag to `True`

```python
import torch
from torch import nn
from g_mlp_pytorch import gMLP

model = gMLP(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 256,
causal = True,
use_circulant_matrix = True # set to True
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
```

Finally, you can also use multi-headedness, as proposed by Peng Bo in the blogpost. To do so, just set `heads` to be greater than `1`
An independent researcher proposes using a multi-headed approach for gMLPs in <a href="https://zhuanlan.zhihu.com/p/395005917">a blogpost on Zhihu</a>. To do so, just set `heads` to be greater than `1`

```python
import torch
@@ -124,8 +103,8 @@ model = gMLP(
depth = 6,
seq_len = 256,
causal = True,
heads = 4, # 4 heads
use_circulant_matrix = True # set to True
circulant_matrix = True,
heads = 4 # 4 heads
)

x = torch.randint(0, 20000, (1, 256))
g_mlp_pytorch/g_mlp_pytorch.py (18 changes: 9 additions & 9 deletions)
@@ -80,7 +80,7 @@ def __init__(
act = nn.Identity(),
heads = 1,
init_eps = 1e-3,
use_circulant_matrix = False
circulant_matrix = False
):
super().__init__()
dim_out = dim // 2
@@ -92,12 +92,12 @@ def __init__(

# parameters

if use_circulant_matrix:
if circulant_matrix:
self.circulant_pos_x = nn.Parameter(torch.ones(heads, dim_seq))
self.circulant_pos_y = nn.Parameter(torch.ones(heads, dim_seq))

self.use_circulant_matrix = use_circulant_matrix
shape = (heads, dim_seq,) if use_circulant_matrix else (heads, dim_seq, dim_seq)
self.circulant_matrix = circulant_matrix
shape = (heads, dim_seq,) if circulant_matrix else (heads, dim_seq, dim_seq)
weight = torch.zeros(shape)

self.weight = nn.Parameter(weight)
@@ -114,7 +114,7 @@ def forward(self, x, gate_res = None):

weight, bias = self.weight, self.bias

if self.use_circulant_matrix:
if self.circulant_matrix:
# build the circulant matrix

dim_seq = weight.shape[-1]
@@ -157,7 +157,7 @@ def __init__(
attn_dim = None,
causal = False,
act = nn.Identity(),
use_circulant_matrix = False
circulant_matrix = False
):
super().__init__()
self.proj_in = nn.Sequential(
@@ -167,7 +167,7 @@

self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else None

self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act, heads, use_circulant_matrix = use_circulant_matrix)
self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act, heads, circulant_matrix = circulant_matrix)
self.proj_out = nn.Linear(dim_ff // 2, dim)

def forward(self, x):
@@ -194,7 +194,7 @@ def __init__(
prob_survival = 1.,
causal = False,
act = nn.Identity(),
use_circulant_matrix = False
circulant_matrix = False
):
super().__init__()
assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
@@ -205,7 +205,7 @@

self.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()

self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act, use_circulant_matrix = use_circulant_matrix))) for i in range(depth)])
self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act, circulant_matrix = circulant_matrix))) for i in range(depth)])

self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'g-mlp-pytorch',
packages = find_packages(),
version = '0.1.2',
version = '0.1.4',
license='MIT',
description = 'gMLP - Pytorch',
author = 'Phil Wang',
