From 6118e388b8127df9911dad9ab9385f88d2613113 Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Mon, 29 Aug 2022 11:15:27 +0200
Subject: [PATCH 1/2] Add dtype

---
 src/transformers/models/donut/modeling_donut_swin.py | 6 +++---
 src/transformers/models/swin/modeling_swin.py        | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index 78e5cc81c19885..4f23d979361501 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -538,10 +538,10 @@ def set_shift_and_window_size(self, input_resolution):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width):
+    def get_attn_mask(self, height, width, dtype):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1))
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -600,7 +600,7 @@ def forward(
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad)
+        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device)
 
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 48c9b8cccf9ec0..58d01d1cdfd45d 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -604,10 +604,10 @@ def set_shift_and_window_size(self, input_resolution):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width):
+    def get_attn_mask(self, height, width, dtype):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1))
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -666,7 +666,7 @@ def forward(
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad)
+        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device)
 

From c5a4d3b41f4a999ec30f40cd1a0259ff6fe3e46a Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Mon, 29 Aug 2022 11:21:39 +0200
Subject: [PATCH 2/2] Fix Swinv2 as well

---
 src/transformers/models/swinv2/modeling_swinv2.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index 52f836d5b91d38..890530691dd3a6 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -676,10 +676,10 @@ def set_shift_and_window_size(self, input_resolution):
                 else target_shift_size[0]
             )
 
-    def get_attn_mask(self, height, width):
+    def get_attn_mask(self, height, width, dtype):
         if self.shift_size > 0:
             # calculate attention mask for shifted window multihead self attention
-            img_mask = torch.zeros((1, height, width, 1))
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -736,7 +736,7 @@ def forward(
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad)
+        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device)
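Note (not part of the patch): the sketch below is a minimal half-precision sanity check of what these hunks change. Before the change, get_attn_mask always built the mask in float32, so adding it to float16 attention scores promoted them out of half precision and broke a pure fp16 forward pass through the shifted-window layers; with the mask created in hidden_states.dtype the whole pass stays in fp16. The CUDA device and the "microsoft/swin-tiny-patch4-window7-224" checkpoint are illustrative assumptions only.

    import torch
    from transformers import SwinModel

    # Load the model directly in float16 (assumed public checkpoint).
    model = SwinModel.from_pretrained(
        "microsoft/swin-tiny-patch4-window7-224", torch_dtype=torch.float16
    ).to("cuda")
    model.eval()

    # A float16 input makes every intermediate tensor, including the
    # shifted-window attention mask, need to live in half precision.
    pixel_values = torch.randn(1, 3, 224, 224, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        outputs = model(pixel_values)

    print(outputs.last_hidden_state.dtype)  # expected: torch.float16

The same check applies to DonutSwinModel and Swinv2Model, since all three files receive the identical fix.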