diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 1de49bf19eebb1..26aecaeb1ad9b7 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -104,7 +104,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) token_type_ids = None if self.use_token_type_ids: diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index e59a1726b0735d..fcb1f2495aab8a 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -107,7 +107,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) token_type_ids = None if self.use_token_type_ids: diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index e25a31b3ff4f0d..864db992772772 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -104,7 +104,7 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)) + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) token_type_ids = None if self.use_token_type_ids: