diff --git a/setup.py b/setup.py
index 0a4e682..adeab3c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'sinkhorn_transformer',
   packages = find_packages(exclude=['examples']),
-  version = '0.11.0',
+  version = '0.11.1',
   license='MIT',
   description = 'Sinkhorn Transformer - Sparse Sinkhorn Attention',
   author = 'Phil Wang',
diff --git a/sinkhorn_transformer/sinkhorn_transformer.py b/sinkhorn_transformer/sinkhorn_transformer.py
index 443c969..a3a67c4 100644
--- a/sinkhorn_transformer/sinkhorn_transformer.py
+++ b/sinkhorn_transformer/sinkhorn_transformer.py
@@ -718,7 +718,7 @@ def __init__(self, num_tokens, dim, max_seq_len, depth, heads = 8, dim_head = No
         if emb_dim != dim:
             self.sinkhorn_transformer = ProjectInOut(self.sinkhorn_transformer, emb_dim, dim, project_out = (not return_embeddings))

-        self.norm = nn.LayerNorm(dim)
+        self.norm = nn.LayerNorm(emb_dim)
         self.to_logits = identity if return_embeddings else nn.Linear(emb_dim, num_tokens)

     def forward(self, x, **kwargs):
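
For context on the second hunk: when `emb_dim != dim`, `ProjectInOut` projects the embeddings from `emb_dim` to `dim` before the transformer and back to `emb_dim` afterwards, so the tensor reaching `self.norm` has `emb_dim` features; `nn.LayerNorm(dim)` would then be applied to the wrong width (and `self.to_logits` already expects `emb_dim`). A minimal sketch of that shape argument, using placeholder sizes rather than anything from the library:

```python
import torch
import torch.nn as nn

# Hypothetical sizes chosen only for illustration
emb_dim, dim, seq_len, batch = 128, 512, 16, 2

# Stand-in for the transformer output after ProjectInOut has projected
# back out: its last dimension is emb_dim, not dim.
transformer_out = torch.randn(batch, seq_len, emb_dim)

norm_before_fix = nn.LayerNorm(dim)      # expects last dim == 512
norm_after_fix  = nn.LayerNorm(emb_dim)  # expects last dim == 128

# The fixed norm matches the emb_dim-wide output.
print(norm_after_fix(transformer_out).shape)  # torch.Size([2, 16, 128])

# The pre-fix norm raises because the normalized shape does not match.
try:
    norm_before_fix(transformer_out)
except RuntimeError as e:
    print("LayerNorm(dim) fails on emb_dim-wide output:", e)
```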