Skip to content

Commit

Permalink
[tests] make 2 tests device-agnostic (#30008)
Browse files Browse the repository at this point in the history
add torch device
  • Loading branch information
faaany authored and ArthurZucker committed Apr 22, 2024
1 parent d91105b commit fbd45ec
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tests/models/blip_2/test_modeling_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ def test_inference_t5_multi_accelerator(self):

# prepare image
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16)
inputs = processor(images=image, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)

predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
Expand All @@ -1003,7 +1003,7 @@ def test_inference_t5_multi_accelerator(self):

# image and context
prompt = "Question: which city is this? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)

predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_model_parallelism_gpt2(self):

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(inputs["input_ids"].to(0))
output = model.generate(inputs["input_ids"].to(f"{torch_device}:0"))

text_output = tokenizer.decode(output[0].tolist())
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
Expand Down

0 comments on commit fbd45ec

Please sign in to comment.