Skip to content

Commit eb913e2

Browse files
committed
more test cases.
1 parent 962483b commit eb913e2

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,13 +1084,18 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
10841084
assert loaded_model
10851085
assert new_output.sample.shape == (4, 4, 16, 16)
10861086

1087-
def test_wrong_device_map_raises_error(self):
1087+
@parameterized.expand(
1088+
[
1089+
(-1, "You can't pass device_map as a negative int"),
1090+
("foo", "When passing device_map as a string, the value needs to be a device name"),
1091+
]
1092+
)
1093+
def test_wrong_device_map_raises_error(self, device_map, msg_substring):
10881094
with self.assertRaises(ValueError) as err_ctx:
10891095
_ = self.model_class.from_pretrained(
1090-
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=-1
1096+
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
10911097
)
10921098

1093-
msg_substring = "You can't pass device_map as a negative int"
10941099
assert msg_substring in str(err_ctx.exception)
10951100

10961101
@parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])

0 commit comments

Comments
 (0)