Skip to content

Commit

Permalink
BLIP - fix pt-tf equivalence test (#30258)
Browse files Browse the repository at this point in the history
* BLIP - fix pt-tf equivalence test

* Update tests/models/blip/test_modeling_blip.py

* Update more model tests
  • Loading branch information
amyeroberts authored and ydshieh committed Apr 23, 2024
1 parent 34a3f76 commit f292cf1
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/models/blip/test_modeling_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=Tru
self.text_model_tester = BlipTextModelTester(parent, **text_kwargs)
self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs)
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
self.seq_length = self.text_model_tester.seq_length # need seq_length for pt-tf equivalence test
self.is_training = is_training

def prepare_config_and_inputs(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/models/blip_2/test_modeling_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def __init__(
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
self.text_model_tester = Blip2TextModelDecoderOnlyTester(parent, **text_kwargs)
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
self.is_training = is_training
self.num_query_tokens = num_query_tokens

Expand Down Expand Up @@ -618,6 +619,7 @@ def __init__(
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs)
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
self.is_training = is_training
self.num_query_tokens = num_query_tokens

Expand Down
1 change: 1 addition & 0 deletions tests/models/instructblip/test_modeling_instructblip.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def __init__(
self.qformer_model_tester = InstructBlipQFormerModelTester(parent, **qformer_kwargs)
self.text_model_tester = InstructBlipTextModelDecoderOnlyTester(parent, **text_kwargs)
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
self.is_training = is_training
self.num_query_tokens = num_query_tokens

Expand Down
1 change: 1 addition & 0 deletions tests/models/pix2struct/test_modeling_pix2struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=Tru
self.text_model_tester = Pix2StructTextModelTester(parent, **text_kwargs)
self.vision_model_tester = Pix2StructVisionModelTester(parent, **vision_kwargs)
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
self.is_training = is_training

def prepare_config_and_inputs(self):
Expand Down

0 comments on commit f292cf1

Please sign in to comment.