diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
index 3cb702b4f6a4bd..6ecec70623b280 100644
--- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
@@ -254,7 +254,7 @@ def random_masking(self, sequence: tf.Tensor, noise: Optional[tf.Tensor] = None)
 
         # keep the first subset
         ids_keep = ids_shuffle[:, :len_keep]
-        sequence_masked = tf.gather(
+        sequence_unmasked = tf.gather(
             sequence,
             axis=1,
             batch_dims=1,
@@ -271,7 +271,7 @@ def random_masking(self, sequence: tf.Tensor, noise: Optional[tf.Tensor] = None)
         # unshuffle to get the binary mask
         mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)
 
-        return sequence_masked, mask, ids_restore
+        return sequence_unmasked, mask, ids_restore
 
     def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index 0667bdd73c5545..1031b7088774f4 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -251,7 +251,7 @@ def random_masking(self, sequence, noise=None):
 
         # keep the first subset
         ids_keep = ids_shuffle[:, :len_keep]
-        sequence_masked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
+        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
 
         # generate the binary mask: 0 is keep, 1 is remove
         mask = torch.ones([batch_size, seq_length], device=sequence.device)
@@ -259,7 +259,7 @@ def random_masking(self, sequence, noise=None):
         # unshuffle to get the binary mask
         mask = torch.gather(mask, dim=1, index=ids_restore)
 
-        return sequence_masked, mask, ids_restore
+        return sequence_unmasked, mask, ids_restore
 
     def forward(self, pixel_values, noise=None):
         batch_size, num_channels, height, width = pixel_values.shape
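
For reviewers, a minimal standalone sketch of the gather-based masking this diff touches, written in plain PyTorch outside the library (the shapes and variable values here are illustrative, not taken from the model). It shows that the tensor built from ids_keep holds exactly the patches that survive masking (the positions where mask == 0), which is why sequence_unmasked is the accurate name for the old sequence_masked.

import torch

# illustrative shapes, not the model's real configuration
batch_size, seq_length, dim = 2, 8, 4
mask_ratio = 0.5
len_keep = int(seq_length * (1 - mask_ratio))

sequence = torch.arange(batch_size * seq_length * dim, dtype=torch.float32)
sequence = sequence.reshape(batch_size, seq_length, dim)
noise = torch.rand(batch_size, seq_length)  # uniform noise in [0, 1)

# per-sample shuffle: patches with small noise are kept, large noise removed
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)

# keep the first subset: these indices select the surviving patches
ids_keep = ids_shuffle[:, :len_keep]
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

# binary mask (0 is keep, 1 is remove), unshuffled back to the original patch order
mask = torch.ones(batch_size, seq_length)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)

# every row of sequence_unmasked is a kept (mask == 0) patch of `sequence`
assert sequence_unmasked.shape == (batch_size, len_keep, dim)
for b in range(batch_size):
    kept = sequence[b][mask[b] == 0]
    assert torch.equal(sequence_unmasked[b].sort(dim=0).values, kept.sort(dim=0).values)

The assertions pass regardless of the noise draw: gathering with ids_keep and zeroing mask[:, :len_keep] before unshuffling describe the same subset of patches, so the rename is purely a readability fix with no behavior change.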