Skip to content

Commit cabca39

Browse files
dimartfmassa
authored andcommitted
Fix make_grid: support any number of channels in tensor (#1300)
1 parent 7f7e766 commit cabca39

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchvision/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def norm_range(t, range):
7474
xmaps = min(nrow, nmaps)
7575
ymaps = int(math.ceil(float(nmaps) / xmaps))
7676
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
77-
grid = tensor.new_full((3, height * ymaps + padding, width * xmaps + padding), pad_value)
77+
num_channels = tensor.size(1)
78+
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
7879
k = 0
7980
for y in irange(ymaps):
8081
for x in irange(xmaps):

0 commit comments

Comments
 (0)