diff --git a/src/transformers/models/segformer/modeling_tf_segformer.py b/src/transformers/models/segformer/modeling_tf_segformer.py index 83a6b9a7530f11..86eb3249382996 100644 --- a/src/transformers/models/segformer/modeling_tf_segformer.py +++ b/src/transformers/models/segformer/modeling_tf_segformer.py @@ -79,7 +79,7 @@ def call(self, x: tf.Tensor, training=None): class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer): """Construct the overlapping patch embeddings.""" - def __init__(self, patch_size, stride, hidden_size, **kwargs): + def __init__(self, patch_size, stride, num_channels, hidden_size, **kwargs): super().__init__(**kwargs) self.padding = tf.keras.layers.ZeroPadding2D(padding=patch_size // 2) self.proj = tf.keras.layers.Conv2D( @@ -87,6 +87,7 @@ def __init__(self, patch_size, stride, hidden_size, **kwargs): ) self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm") + self.num_channels = num_channels def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]: embeddings = self.proj(self.padding(pixel_values)) @@ -363,6 +364,7 @@ def __init__(self, config: SegformerConfig, **kwargs): TFSegformerOverlapPatchEmbeddings( patch_size=config.patch_sizes[i], stride=config.strides[i], + num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], hidden_size=config.hidden_sizes[i], name=f"patch_embeddings.{i}", )