Skip to content

Commit

Permalink
Fix other models affected by test change
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed May 2, 2023
1 parent 08273d1 commit 38f3d35
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
3 changes: 3 additions & 0 deletions tests/models/data2vec/test_modeling_tf_data2vec_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def test_model_from_pretrained(self):
model = TFData2VecVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)

def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)


# We will verify our results on an image of cute cats
def prepare_img():
Expand Down
3 changes: 3 additions & 0 deletions tests/models/sam/test_modeling_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ def test_retain_grad_hidden_states_attentions(self):
def test_hidden_states_output(self):
pass

def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)

@slow
def test_model_from_pretrained(self):
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
Expand Down
4 changes: 2 additions & 2 deletions tests/models/vit_mae/test_modeling_tf_vit_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def prepare_numpy_arrays(inputs_dict):

# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5):
# make masks reproducible
np.random.seed(2)

Expand All @@ -279,7 +279,7 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
# PT inputs will be prepared in `super().check_pt_tf_models()` with this added `noise` argument
tf_inputs_dict["noise"] = tf_noise

super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict, tol=tol)

# overwrite from common since TFViTMAEForPretraining outputs loss along with
# logits and mask indices. loss and mask indices are not suitable for integration
Expand Down

0 comments on commit 38f3d35

Please sign in to comment.