Skip to content

Commit c5780ca

Browse files
authored
Merge branch 'main' into try-fix-mypy
2 parents 9011318 + 8f61f4c commit c5780ca

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

test/test_transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,12 +452,12 @@ def test_resize_size_equals_small_edge_size(height, width):
452452

453453

454454
class TestPad:
455-
def test_pad(self):
455+
@pytest.mark.parametrize("fill", [85, 85.0])
456+
def test_pad(self, fill):
456457
height = random.randint(10, 32) * 2
457458
width = random.randint(10, 32) * 2
458459
img = torch.ones(3, height, width, dtype=torch.uint8)
459460
padding = random.randint(1, 20)
460-
fill = random.randint(1, 50)
461461
result = transforms.Compose(
462462
[
463463
transforms.ToPILImage(),
@@ -484,7 +484,7 @@ def test_pad_with_tuple_of_pad_values(self):
484484
output = transforms.Pad(padding)(img)
485485
assert output.size == (width + padding[0] * 2, height + padding[1] * 2)
486486

487-
padding = tuple(random.randint(1, 20) for _ in range(4))
487+
padding = [random.randint(1, 20) for _ in range(4)]
488488
output = transforms.Pad(padding)(img)
489489
assert output.size[0] == width + padding[0] + padding[2]
490490
assert output.size[1] == height + padding[1] + padding[3]

torchvision/transforms/functional_pil.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def pad(
154154

155155
if not isinstance(padding, (numbers.Number, tuple, list)):
156156
raise TypeError("Got inappropriate padding arg")
157-
if not isinstance(fill, (numbers.Number, str, tuple)):
157+
if not isinstance(fill, (numbers.Number, str, tuple, list)):
158158
raise TypeError("Got inappropriate fill arg")
159159
if not isinstance(padding_mode, str):
160160
raise TypeError("Got inappropriate padding_mode arg")
@@ -301,6 +301,12 @@ def _parse_fill(
301301

302302
fill = tuple(fill)
303303

304+
if img.mode != "F":
305+
if isinstance(fill, (list, tuple)):
306+
fill = tuple(int(x) for x in fill)
307+
else:
308+
fill = int(fill)
309+
304310
return {name: fill}
305311

306312

torchvision/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def __init__(self, padding, fill=0, padding_mode="constant"):
428428
if not isinstance(padding, (numbers.Number, tuple, list)):
429429
raise TypeError("Got inappropriate padding arg")
430430

431-
if not isinstance(fill, (numbers.Number, str, tuple)):
431+
if not isinstance(fill, (numbers.Number, str, tuple, list)):
432432
raise TypeError("Got inappropriate fill arg")
433433

434434
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:

0 commit comments

Comments
 (0)