From e2ae4d756436398ffc5b123cb0932d239f7f933c Mon Sep 17 00:00:00 2001
From: Seyed Alireza Fatemi Jahromi
Date: Sun, 8 Oct 2023 12:49:59 +0000
Subject: [PATCH] use bool masks

---
 src/open_clip/transformer.py | 2 +-
 tests/test_hf_model.py       | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 0a30e9466..ce5e0d3f7 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -586,7 +586,7 @@ def build_attention_mask(self):
     def build_cls_mask(self, text, cast_dtype: torch.dtype):
         cls_mask = (text != self.pad_id).unsqueeze(1)
-        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
+        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
         additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
         additive_mask.fill_(0)
         additive_mask.masked_fill_(~cls_mask, float("-inf"))
diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py
index 79df2f2cf..f9191f1f4 100644
--- a/tests/test_hf_model.py
+++ b/tests/test_hf_model.py
@@ -8,8 +8,8 @@
 def test_poolers():
     bs, sl, d = 2, 10, 5
     h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
-    mask = torch.ones(bs, sl, dtype=torch.long)
-    mask[:2, 6:] = 0
+    mask = torch.ones(bs, sl, dtype=torch.bool)
+    mask[:2, 6:] = False
     x = BaseModelOutput(h)
     for name, cls in _POOLERS.items():
         pooler = cls()
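
Note (not part of the patch): the sketch below is a minimal standalone
reconstruction of the mask-building path the first hunk touches, showing
why the dtype change matters. The pad_id value and the toy text tensor
are illustrative assumptions, not values from the repository. The change
is motivated by recent PyTorch versions being stricter about dtypes:
masked_fill_ requires a boolean mask, and filling a bool tensor with the
float 1.0 in F.pad is what the patch replaces with the matching bool
value True.

    import torch
    import torch.nn.functional as F

    pad_id = 0                                 # hypothetical pad token id
    text = torch.tensor([[5, 7, 3, 0, 0]])     # toy batch of token ids
    cls_mask = (text != pad_id).unsqueeze(1)   # bool mask, shape (1, 1, 5)

    # Pad with value=True so the fill value matches the tensor's bool
    # dtype; the pre-patch value=1.0 mixes a float fill into a bool
    # tensor, which newer PyTorch releases reject or warn about.
    cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)

    # Convert the bool mask into the additive float mask that attention
    # expects: 0 where attending is allowed, -inf where it is masked.
    additive_mask = torch.zeros(cls_mask.shape, dtype=torch.float32)
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    print(additive_mask.shape)  # torch.Size([1, 6, 6])

The test-side hunk makes the same dtype fix in miniature: an attention
mask built as torch.bool with True/False entries rather than a long
tensor of 1s and 0s, so the poolers receive the mask dtype that
masking ops accept.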