Skip to content

Commit

Permalink
Take care of masking if the data requires it
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 5, 2021
1 parent f0455b6 commit 27c7a13
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
14 changes: 11 additions & 3 deletions perceiver_pytorch/perceiver_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 75, drop
nn.Dropout(dropout)
)

def forward(self, x, context = None):
def forward(self, x, context = None, mask = None):
h = self.heads

q = self.to_q(x)
Expand All @@ -93,7 +93,15 @@ def forward(self, x, context = None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

if exists(mask):
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h = h)
sim.masked_fill_(~mask, max_neg_value)

# attention, what we cannot get enough of
attn = sim.softmax(dim = -1)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
return self.to_out(out)
Expand Down Expand Up @@ -143,14 +151,14 @@ def __init__(

self.to_logits = nn.Linear(latent_dim, num_classes)

def forward(self, data):
def forward(self, data, mask = None):
b = data.shape[0]
data = fourier_encode(data, self.num_fourier_features)

x = repeat(self.latents, 'n d -> b n d', b = b)

for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
x = cross_attn(x, context = data) + x
x = cross_attn(x, context = data, mask = mask) + x
x = cross_ff(x) + x
x = latent_attn(x) + x
x = latent_ff(x) + x
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'perceiver-pytorch',
packages = find_packages(),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'Perceiver - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 27c7a13

Please sign in to comment.