Skip to content

Commit

Permalink
Add some attributes we're going to need later
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Dec 4, 2023
1 parent c4504a0 commit 69add2c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/models/segformer/modeling_tf_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ def call(self, x: tf.Tensor, training=None):
class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer):
"""Construct the overlapping patch embeddings."""

def __init__(self, patch_size, stride, hidden_size, **kwargs):
def __init__(self, patch_size, stride, num_channels, hidden_size, **kwargs):
super().__init__(**kwargs)
self.padding = tf.keras.layers.ZeroPadding2D(padding=patch_size // 2)
self.proj = tf.keras.layers.Conv2D(
filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj"
)

self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
self.num_channels = num_channels

def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]:
embeddings = self.proj(self.padding(pixel_values))
Expand Down Expand Up @@ -363,6 +364,7 @@ def __init__(self, config: SegformerConfig, **kwargs):
TFSegformerOverlapPatchEmbeddings(
patch_size=config.patch_sizes[i],
stride=config.strides[i],
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
hidden_size=config.hidden_sizes[i],
name=f"patch_embeddings.{i}",
)
Expand Down

0 comments on commit 69add2c

Please sign in to comment.