diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 763e25018e5dff..8627220c71aa51 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -1014,8 +1014,9 @@ def input_audio(self): ) def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs): - model1 = class1.from_pretrained(self.repo_id).to(torch_device) - model2 = class2.from_pretrained(self.repo_id).to(torch_device) + # half-precision loading to limit GPU usage + model1 = class1.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device) + model2 = class2.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device) set_seed(0) output_1 = model1.generate(**inputs, **class1_kwargs)