From b9d717f104ee20a5836d032567433653509f6d55 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Mon, 11 Dec 2023 08:18:41 +0000 Subject: [PATCH] Fix `SeamlessM4Tv2ModelIntegrationTest` (#27911) change dtype of some integration tests --- .../models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 763e25018e5dff..8627220c71aa51 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -1014,8 +1014,9 @@ def input_audio(self): ) def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs): - model1 = class1.from_pretrained(self.repo_id).to(torch_device) - model2 = class2.from_pretrained(self.repo_id).to(torch_device) + # half-precision loading to limit GPU usage + model1 = class1.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device) + model2 = class2.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device) set_seed(0) output_1 = model1.generate(**inputs, **class1_kwargs)