From 983c6a301fbfc04d1814c6f4e2e98c94121761d2 Mon Sep 17 00:00:00 2001
From: Seunghwan Hong
Date: Fri, 29 Jul 2022 13:29:46 +0900
Subject: [PATCH 1/3] Refactor `TFSwinLayer` to increase serving compatibility

Signed-off-by: Seunghwan Hong
---
 .../models/swin/modeling_tf_swin.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py
index dfedd2d885d7..94b6e4bbaa64 100644
--- a/src/transformers/models/swin/modeling_tf_swin.py
+++ b/src/transformers/models/swin/modeling_tf_swin.py
@@ -222,13 +222,10 @@ def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
     return windows
 
 
-def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor:
+def window_reverse(windows: tf.Tensor, batch_size: int, window_size: int, height: int, width: int) -> tf.Tensor:
     """
     Merges windows to produce higher resolution features.
     """
-    x = shape_list(windows)[0]
-    y = tf.cast(height * width / (window_size * window_size), tf.int32)
-    batch_size = int(x / y)
     windows = tf.reshape(
         windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     )
@@ -688,16 +685,18 @@ def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: i
         img_mask = tf.expand_dims(img_mask, -1)
         img_mask = tf.expand_dims(img_mask, 0)
 
-        mask_windows = window_partition(img_mask, self.window_size)
-        mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
+        mask_windows = window_partition(img_mask, window_size)
+        mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
         attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
         attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
         attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
         return attn_mask
 
-    def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]:
-        pad_right = (self.window_size - width % self.window_size) % self.window_size
-        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+    def maybe_pad(
+        self, hidden_states: tf.Tensor, height: int, width: int, window_size: int
+    ) -> Tuple[tf.Tensor, tf.Tensor]:
+        pad_right = (window_size - width % window_size) % window_size
+        pad_bottom = (window_size - height % window_size) % window_size
         pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
         hidden_states = tf.pad(hidden_states, pad_values)
         pad_values = tf.reshape(pad_values, (-1,))
@@ -746,7 +745,7 @@ def call(
         attention_output = attention_outputs[0]
 
         attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels))
-        shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad)
+        shifted_windows = window_reverse(attention_windows, batch_size, window_size, height_pad, width_pad)
 
         # reverse cyclic shift
         if shift_size > 0:

From 79d9d220ccd9df62a86d130bc19d4cfc4606da4d Mon Sep 17 00:00:00 2001
From: Seunghwan Hong
Date: Fri, 29 Jul 2022 14:12:40 +0900
Subject: [PATCH 2/3] Fix missed parameters while refactoring

Signed-off-by: Seunghwan Hong
---
 src/transformers/models/swin/modeling_tf_swin.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py
index 94b6e4bbaa64..f6c05d2f7573 100644
--- a/src/transformers/models/swin/modeling_tf_swin.py
+++ b/src/transformers/models/swin/modeling_tf_swin.py
@@ -693,7 +693,7 @@ def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: i
         return attn_mask
 
     def maybe_pad(
-        self, hidden_states: tf.Tensor, height: int, width: int, window_size: int
+        self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
     ) -> Tuple[tf.Tensor, tf.Tensor]:
         pad_right = (window_size - width % window_size) % window_size
         pad_bottom = (window_size - height % window_size) % window_size
@@ -722,7 +722,7 @@ def call(
         hidden_states = self.layernorm_before(hidden_states, training=training)
         hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
         # pad hidden_states to multiples of window size
-        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+        hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width)
         _, height_pad, width_pad, _ = shape_list(hidden_states)
 
         # cyclic shift

From 768381a48262b7988c37a52500e197c4c6382d98 Mon Sep 17 00:00:00 2001
From: Seunghwan Hong
Date: Thu, 4 Aug 2022 12:17:00 +0900
Subject: [PATCH 3/3] Fix window_reverse to calculate batch size

Signed-off-by: Seunghwan Hong
Co-Authored-By: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/models/swin/modeling_tf_swin.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py
index f6c05d2f7573..a574d8ea24c7 100644
--- a/src/transformers/models/swin/modeling_tf_swin.py
+++ b/src/transformers/models/swin/modeling_tf_swin.py
@@ -222,10 +222,13 @@ def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
     return windows
 
 
-def window_reverse(windows: tf.Tensor, batch_size: int, window_size: int, height: int, width: int) -> tf.Tensor:
+def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor:
     """
     Merges windows to produce higher resolution features.
     """
+    x = tf.shape(windows)[0]
+    y = tf.cast(height * width / (window_size * window_size), tf.int32)
+    batch_size = tf.math.floordiv(x, y)
     windows = tf.reshape(
         windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     )
@@ -745,7 +748,7 @@ def call(
         attention_output = attention_outputs[0]
 
         attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels))
-        shifted_windows = window_reverse(attention_windows, batch_size, window_size, height_pad, width_pad)
+        shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad)
 
         # reverse cyclic shift
         if shift_size > 0:
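
Note (illustration, not part of the patch series): the sketch below shows why PATCH 3/3 computes batch_size with tf.shape and tf.math.floordiv rather than the Python-level int(x / y) that PATCH 1/3 removed. When the model is traced for serving with an undefined batch dimension, the first dimension of windows is not statically known, so calling int() on it raises during tracing; keeping the arithmetic in TensorFlow ops keeps the graph traceable, which is the serving compatibility the series is after. The function body mirrors window_reverse as it stands after PATCH 3/3 (the trailing transpose/reshape come from the surrounding code in modeling_tf_swin.py, not from the diff context); the tf.function usage and tensor shapes at the bottom are made up for illustration.

import tensorflow as tf


def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor:
    """Merges windows back into a (batch, height, width, channels) feature map."""
    x = tf.shape(windows)[0]  # dynamic shape: stays a tensor even when the batch dim is unknown
    y = tf.cast(height * width / (window_size * window_size), tf.int32)  # windows per image
    batch_size = tf.math.floordiv(x, y)  # graph-friendly division; int(x / y) would raise while tracing
    windows = tf.reshape(
        windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    )
    windows = tf.transpose(windows, (0, 1, 3, 2, 4, 5))
    windows = tf.reshape(windows, (batch_size, height, width, -1))
    return windows


# Tracing with an undefined batch dimension (the serving case) succeeds,
# because no Python int() is applied to a possibly-unknown shape entry.
fn = tf.function(lambda w: window_reverse(w, window_size=2, height=4, width=4))
concrete = fn.get_concrete_function(tf.TensorSpec([None, 2, 2, 8], tf.float32))
print(concrete.structured_outputs.shape)  # batch dimension stays dynamic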