diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py
index dc0e7131628b89..2f9bd27b0e0006 100644
--- a/src/transformers/models/swin/modeling_tf_swin.py
+++ b/src/transformers/models/swin/modeling_tf_swin.py
@@ -226,9 +226,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
     """
     Merges windows to produce higher resolution features.
     """
-    x = shape_list(windows)[0]
+    x = tf.shape(windows)[0]
     y = tf.cast(height * width / (window_size * window_size), tf.int32)
-    batch_size = int(x / y)
+    batch_size = tf.math.floordiv(x, y)
     windows = tf.reshape(
         windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     )
@@ -695,16 +695,18 @@ def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: i
         img_mask = tf.expand_dims(img_mask, -1)
         img_mask = tf.expand_dims(img_mask, 0)
 
-        mask_windows = window_partition(img_mask, self.window_size)
-        mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
+        mask_windows = window_partition(img_mask, window_size)
+        mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
         attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
         attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
         attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
         return attn_mask
 
-    def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]:
-        pad_right = (self.window_size - width % self.window_size) % self.window_size
-        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+    def maybe_pad(
+        self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
+    ) -> Tuple[tf.Tensor, tf.Tensor]:
+        pad_right = (window_size - width % window_size) % window_size
+        pad_bottom = (window_size - height % window_size) % window_size
         pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
         hidden_states = tf.pad(hidden_states, pad_values)
         pad_values = tf.reshape(pad_values, (-1,))
@@ -730,7 +732,7 @@ def call(
         hidden_states = self.layernorm_before(hidden_states, training=training)
         hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
         # pad hidden_states to multiples of window size
-        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+        hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width)
         _, height_pad, width_pad, _ = shape_list(hidden_states)
 
         # cyclic shift
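
Why the first hunk matters: `shape_list(windows)[0]` yields a Python int only when the batch dimension is statically known; under a `tf.function` trace with an unknown batch size it yields a symbolic tensor, and `int(x / y)` then fails. Switching to `tf.shape` plus `tf.math.floordiv` keeps the computation symbolic so the graph builds for any batch size. Below is a minimal, self-contained sketch (not part of the PR) that mirrors the `window_reverse` change; the function name `reverse_batch_size` and the concrete sizes (height=8, width=8, window_size=4, 8 channels) are illustrative assumptions, not values from the model.

import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4, 4, 8], dtype=tf.float32)])
def reverse_batch_size(windows: tf.Tensor) -> tf.Tensor:
    height, width, window_size = 8, 8, 4
    # Number of windows is unknown at trace time, so this is a symbolic scalar.
    num_windows = tf.shape(windows)[0]
    windows_per_image = tf.cast(height * width / (window_size * window_size), tf.int32)
    # Tensor integer division; int(num_windows / windows_per_image) would raise
    # during tracing because num_windows is a symbolic tensor, not a Python int.
    batch_size = tf.math.floordiv(num_windows, windows_per_image)
    return tf.reshape(
        windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    )


# Example call: 2 images -> 2 * (8 // 4) * (8 // 4) = 8 windows of 4x4 with 8 channels.
print(reverse_batch_size(tf.zeros((8, 4, 4, 8))).shape)  # (2, 2, 2, 4, 4, 8)

The second and third hunks follow the same idea: `get_attn_mask` and `maybe_pad` now take the (possibly dynamically computed) `window_size` as an argument instead of reading the fixed `self.window_size`, so the padded, shifted-window path stays consistent with whatever window size the caller resolved at runtime.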