diff --git a/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py b/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py
index 0f67b05b..1ce76aae 100644
--- a/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py
+++ b/keras_cv_attention_models/swin_transformer_v2/swin_transformer_v2.py
@@ -89,7 +89,9 @@ def window_multi_head_self_attention(inputs, filters=-1, num_heads=4, meta_hidde
     if mask is not None:
         query_blocks = attn.shape[2]
         attn = tf.reshape(attn, [-1, mask.shape[0], num_heads, query_blocks, query_blocks])
-        attn += tf.expand_dims(tf.expand_dims(mask, 1), 0)  # expand dims on batch and num_heads
+        # attn += tf.expand_dims(tf.expand_dims(mask, 1), 0)  # expand dims on batch and num_heads
+        mask = tf.expand_dims(tf.expand_dims(mask, 1), 0)  # expand dims on batch and num_heads
+        attn = keras.layers.Add()([attn, mask])
         attn = tf.reshape(attn, [-1, num_heads, query_blocks, query_blocks])
 
     attention_scores = keras.layers.Softmax(axis=-1, name=name and name + "attention_scores")(attn)
@@ -107,7 +109,7 @@ def window_multi_head_self_attention(inputs, filters=-1, num_heads=4, meta_hidde
 
 
 def make_window_attention_mask(height, width, window_height, window_width, shift_height, shift_width):
-    float_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
+    # float_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
     hh_split = [0, height - window_height, height - shift_height, height]
     ww_split = [0, width - window_width, width - shift_width, width]
     mask_value, total_ww, mask = 0, len(ww_split) - 1, []
@@ -123,7 +125,7 @@ def make_window_attention_mask(height, width, window_height, window_width, shift
     mask = tf.transpose(mask, [0, 2, 1, 3])
     mask = tf.reshape(mask, [-1, window_height * window_width])
     attn_mask = tf.expand_dims(mask, 1) - tf.expand_dims(mask, 2)
-    return tf.cast(tf.where(attn_mask != 0, -100, 0), float_dtype)
+    return tf.cast(tf.where(attn_mask != 0, -100, 0), "float32")
 
 
 def shifted_window_attention(inputs, window_size, num_heads=4, shift_size=0, name=""):
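
Why this change helps under mixed precision (a minimal sketch, not part of the diff): `make_window_attention_mask` now always returns a float32 mask, while `attn` is in the policy's compute dtype (float16 under a `mixed_float16` policy). A raw `attn += mask` would then fail on the dtype mismatch, whereas `keras.layers.Add` casts its floating-point inputs to the layer's compute dtype before adding. The shapes and values below are illustrative only (4 heads, 7x7 windows):

```python
import tensorflow as tf
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

attn = tf.zeros([1, 4, 49, 49], dtype="float16")  # attention logits in the compute dtype
mask = tf.fill([1, 1, 49, 49], -100.0)            # float32 mask, as the patched function returns

# attn + mask would raise InvalidArgumentError (half vs float tensor),
# but the Add layer first autocasts the float32 mask to its float16 compute dtype.
out = keras.layers.Add()([attn, mask])
print(out.dtype)  # float16

keras.mixed_precision.set_global_policy("float32")  # restore the default policy
```

This also explains the hard-coded `"float32"` cast in `make_window_attention_mask`: the mask can stay in full precision unconditionally, and the Add layer handles any downcast at call time.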