diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 4b61f7140cee13..8bbd8587b683f4 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -163,7 +163,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision if v.dtype == bfloat16: v = v.float() - pt_state_dict[k] = v.numpy() + pt_state_dict[k] = v.cpu().numpy() model_prefix = flax_model.base_model_prefix diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 3b6994428088a2..bd9664dd15fd32 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -848,6 +848,7 @@ def test_equivalence_pt_to_flax(self): with self.subTest(model_class.__name__): # load PyTorch class pt_model = model_class(config).eval() + pt_model.to(torch_device) # Flax models don't use the `use_cache` option and cache is not returned as a default. # So we disable `use_cache` here for PyTorch model. pt_model.config.use_cache = False @@ -881,7 +882,7 @@ def test_equivalence_pt_to_flax(self): fx_outputs = fx_model(**fx_inputs).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) @@ -892,7 +893,7 @@ def test_equivalence_pt_to_flax(self): len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" ) for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2) # overwrite from common since FlaxCLIPModel returns nested output # which is not supported in the common test @@ -921,6 +922,7 @@ def test_equivalence_flax_to_pt(self): fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + pt_model.to(torch_device) # make sure weights are tied in PyTorch pt_model.tie_weights() @@ -940,11 +942,12 @@ def test_equivalence_flax_to_pt(self): self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + pt_model_loaded.to(torch_device) with torch.no_grad(): pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() @@ -953,7 +956,7 @@ def test_equivalence_flax_to_pt(self): len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" ) for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) @slow def test_model_from_pretrained(self): diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index c8f76a144be703..35434a280e9ae0 100644 --- a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -297,7 +297,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): # prepare inputs flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() @@ -305,7 +305,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs = fx_model(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5) # PT -> Flax with tempfile.TemporaryDirectory() as tmpdirname: @@ -315,7 +315,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5) # Flax -> PT with tempfile.TemporaryDirectory() as tmpdirname: @@ -330,7 +330,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5) def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index e4e86fb6952777..10cb2b71824e99 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -170,7 +170,7 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict): embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model - ) + ).to(torch_device) self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight)) self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight)) diff --git a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 62ce0d660a0abc..8f210a07d278a6 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -412,7 +412,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): # prepare inputs flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() @@ -420,7 +420,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs = fx_model(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5) # PT -> Flax with tempfile.TemporaryDirectory() as tmpdirname: @@ -430,7 +430,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5) # Flax -> PT with tempfile.TemporaryDirectory() as tmpdirname: @@ -445,7 +445,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5) def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) diff --git a/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py index 98c3a275825b0b..fabef4b8c6de04 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py @@ -241,7 +241,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): # prepare inputs flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() @@ -249,7 +249,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs = fx_model(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5) # PT -> Flax with tempfile.TemporaryDirectory() as tmpdirname: @@ -259,7 +259,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5) # Flax -> PT with tempfile.TemporaryDirectory() as tmpdirname: @@ -274,7 +274,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5) def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) diff --git a/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py index e57a4bd4db8e05..e1e8eb4076c137 100644 --- a/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py +++ b/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py @@ -160,7 +160,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): # prepare inputs flax_inputs = inputs_dict - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()} with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() @@ -168,7 +168,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs = fx_model(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) # PT -> Flax with tempfile.TemporaryDirectory() as tmpdirname: @@ -178,7 +178,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2) # Flax -> PT with tempfile.TemporaryDirectory() as tmpdirname: @@ -193,7 +193,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2) + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2) def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict): config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) diff --git a/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py index e30d0ff01bfdc9..d935c0d27d1ccb 100644 --- a/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py +++ b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py @@ -179,7 +179,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas # prepare inputs inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values} pt_inputs = inputs_dict - flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()} + flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()} with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() @@ -187,7 +187,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas fx_outputs = fx_model(**flax_inputs).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2) # PT -> Flax with tempfile.TemporaryDirectory() as tmpdirname: @@ -197,7 +197,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple() self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2) # Flax -> PT with tempfile.TemporaryDirectory() as tmpdirname: @@ -212,7 +212,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]): - self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2) + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2) def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict): config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)