Skip to content

Commit

Permalink
Refactor TFSwinLayer to increase serving compatibility (#18352)
Browse files Browse the repository at this point in the history
* Refactor `TFSwinLayer` to increase serving compatibility

Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix missed parameters while refactoring

Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix window_reverse to calculate batch size

Signed-off-by: Seunghwan Hong <harrydrippin@gmail.com>
Co-Authored-By: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
harrydrippin and amyeroberts authored Aug 5, 2022
1 parent 575aa6e commit bf174f9
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/transformers/models/swin/modeling_tf_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
"""
Merges windows to produce higher resolution features.
"""
x = shape_list(windows)[0]
x = tf.shape(windows)[0]
y = tf.cast(height * width / (window_size * window_size), tf.int32)
batch_size = int(x / y)
batch_size = tf.math.floordiv(x, y)
windows = tf.reshape(
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
)
Expand Down Expand Up @@ -695,16 +695,18 @@ def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: i
img_mask = tf.expand_dims(img_mask, -1)
img_mask = tf.expand_dims(img_mask, 0)

mask_windows = window_partition(img_mask, self.window_size)
mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
mask_windows = window_partition(img_mask, window_size)
mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
return attn_mask

def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]:
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
def maybe_pad(
self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
) -> Tuple[tf.Tensor, tf.Tensor]:
pad_right = (window_size - width % window_size) % window_size
pad_bottom = (window_size - height % window_size) % window_size
pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
hidden_states = tf.pad(hidden_states, pad_values)
pad_values = tf.reshape(pad_values, (-1,))
Expand All @@ -730,7 +732,7 @@ def call(
hidden_states = self.layernorm_before(hidden_states, training=training)
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width)

_, height_pad, width_pad, _ = shape_list(hidden_states)
# cyclic shift
Expand Down

0 comments on commit bf174f9

Please sign in to comment.