From ae4d88a5b2748f05d717ce91d6cbdcaba499eea4 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 11 Nov 2021 22:01:26 -0800
Subject: [PATCH] 1.1.4

---
 dalle_pytorch/attention.py   | 2 +-
 dalle_pytorch/transformer.py | 2 +-
 setup.py                     | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/dalle_pytorch/attention.py b/dalle_pytorch/attention.py
index 39e807a6..48131e89 100644
--- a/dalle_pytorch/attention.py
+++ b/dalle_pytorch/attention.py
@@ -26,7 +26,7 @@ def max_neg_value(t):
 
 def stable_softmax(t, dim = -1, alpha = 32 ** 2):
     t = t / alpha
-    t = t - torch.amax(t, dim = dim, keepdim = True)
+    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
     return (t * alpha).softmax(dim = dim)
 
 def apply_pos_emb(pos_emb, qkv):
diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index 81bc7b4a..6b62df70 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -33,7 +33,7 @@ def __init__(self, dim):
         self.dim = dim
 
     def forward(self, x):
-        maxes = x.amax(dim = self.dim, keepdim = True)
+        maxes = x.amax(dim = self.dim, keepdim = True).detach()
         return x / maxes
 
 # https://arxiv.org/abs/2103.17239
diff --git a/setup.py b/setup.py
index 29e6396c..99754392 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.1.2',
+  version = '1.1.4',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
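
Note on the change: the patched stable_softmax is reproduced below as a standalone sketch, with comments and a final sanity check added for illustration (they are not part of the repository). Subtracting the row-wise max is a pure numerical shift, and softmax is invariant to such shifts, so detaching the max leaves both the forward values and the mathematical gradient unchanged while preventing autograd from routing any gradient through torch.amax.

import torch

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    # scale down first so intermediate values stay in range
    # (relevant when training in float16)
    t = t / alpha
    # subtract the per-row max as a constant shift; .detach() keeps
    # autograd from sending gradient through the amax/argmax path
    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
    # undo the scaling inside the exponent and take the softmax
    return (t * alpha).softmax(dim = dim)

# illustrative check: the forward output matches a plain softmax
x = torch.randn(2, 8)
assert torch.allclose(stable_softmax(x), x.softmax(dim = -1), atol = 1e-6)

The same reasoning applies to DivideMax in transformer.py: the divisor is treated as a constant normalizer, so gradients flow only through the numerator x.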