
Commit 2c53dac

Fix conv2d channels_last issue for CPU backend (#1394)
For the torchinductor CPU path, the wrong layout format was set for the conv2d channels_last path, which trips a stride-mismatch assertion in the generated code:

```python
def call(primals_1, primals_2, primals_3):
    primals_1_size = primals_1.size()
    s0 = primals_1_size[0]
    primals_3_size = primals_3.size()
    s1 = primals_3_size[0]
    s2 = primals_3_size[2]
    buf0 = aten.convolution(
        primals_3, primals_1, None, (1, 1), (0, 0), (1, 1), False, (0, 0), 1
    )
    assert buf0.size() == (s1, 3, 16, 16)
>   assert buf0.stride() == (768, 256, 16, 1)
E   AssertionError
```

This PR fixes the issue.
1 parent 9263063 commit 2c53dac
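
For context, a minimal sketch of how the failure could be reproduced. This is illustrative, not taken from the commit: it assumes the torchdynamo-era `torchdynamo.optimize("inductor")` entry point, and borrows the module and input shape from the new test below.

```python
import torch
import torchdynamo

# A CPU conv2d whose weight is in channels_last memory format.
m = torch.nn.Conv2d(3, 3, 1, 1).to(memory_format=torch.channels_last)

@torchdynamo.optimize("inductor")
def run(x):
    return m(x)

# Before this fix, the generated CPU code asserted a contiguous output
# stride while aten.convolution actually produced a channels_last output.
run(torch.randn(2, 3, 16, 16))
```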

File tree

2 files changed: +65 -0 lines changed


test/inductor/test_torchinductor.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -68,6 +68,7 @@
     pass

 requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
+
 torchinductor.config.triton.autotune = False  # too slow


@@ -1389,6 +1390,50 @@ def fn(x, w, b):
             check_lowp=False,
         )

+    @unittest.skipIf(HAS_CUDA, "only support cpu channels_last")
+    def test_conv2d_channels_last(self):
+        m = torch.nn.Sequential(
+            torch.nn.Conv2d(3, 3, 1, 1),
+            ToTuple(),
+        )
+        # only weight is channels_last
+        self.common(
+            m.to(memory_format=torch.channels_last),
+            (torch.randn([2, 3, 16, 16]),),
+        )
+        # only activation is channels_last
+        self.common(
+            m,
+            (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),),
+        )
+        # activation and weight are all channels_last
+        self.common(
+            m.to(memory_format=torch.channels_last),
+            (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),),
+        )
+
+    @unittest.skipIf(HAS_CUDA, "only support cpu channels_last")
+    def test_conv3d_channels_last(self):
+        m = torch.nn.Sequential(
+            torch.nn.Conv3d(3, 3, 1, 1),
+            ToTuple(),
+        )
+        # only weight is channels_last
+        self.common(
+            m.to(memory_format=torch.channels_last_3d),
+            (torch.randn([2, 3, 16, 16, 16]),),
+        )
+        # only activation is channels_last
+        self.common(
+            m,
+            (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
+        )
+        # activation and weight are all channels_last
+        self.common(
+            m.to(memory_format=torch.channels_last_3d),
+            (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
+        )
+
     def test_adaptive_avg_pool2d1(self):
         def fn(x):
             return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d(
```

torchinductor/ir.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -1518,6 +1518,12 @@ def is_stride_ordered(self, order):
                 return False
         return True

+    def is_channels_last_stride_ordered(self):
+        # create the channels_last order (NCHW/NCDHW, with C ordered first)
+        order = [0] + list(reversed(range(1, len(self.stride) - 1)))
+        order = [len(order)] + order
+        return self.is_stride_ordered(order)
+
     def as_fixed(self):
         return FixedLayout(
             self.device,
```
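
To make the order construction concrete, here is a standalone sketch (not repo code; the rank interpretation of `order` is my reading of `is_stride_ordered`, where `order[d]` is the rank of dimension `d`'s stride and rank 0 marks the smallest stride). It shows the computed order agreeing with the stride ranking of an actual channels_last tensor:

```python
import torch

def channels_last_order(ndim):
    # Mirrors the order construction above: spatial dims in reverse, then C,
    # with N ranked last (largest stride).
    order = [0] + list(reversed(range(1, ndim - 1)))
    return [len(order)] + order

# Rank the strides of a real channels_last tensor for comparison.
x = torch.empty(2, 3, 16, 16).to(memory_format=torch.channels_last)
dims_by_stride = sorted(range(4), key=lambda d: x.stride(d))
stride_ranks = [dims_by_stride.index(d) for d in range(4)]

print(channels_last_order(4))  # [3, 0, 2, 1]
print(stride_ranks)            # [3, 0, 2, 1] for NHWC strides (768, 1, 48, 3)
print(channels_last_order(5))  # [4, 0, 3, 2, 1], i.e. channels_last_3d
```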
```diff
@@ -3104,6 +3110,20 @@ def create(
             )
         else:
             output_layout_str = "torch.contiguous_format"
+        # If x or weight has a channels_last (2d or 3d) format, the channels_last
+        # path is taken, which aligns with the aten.convolution path (cpu only supports the 2d case now).
+        # TODO: once cpu 3d convolution supports the channels_last path, the size check can be removed.
+        # TODO: the gpu channels_last path depends on the cudnn version, see
+        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvUtils.h.
+        if (
+            x.get_device().type == "cpu"
+            and len(x.get_size()) == 4
+            and (
+                x.get_layout().is_channels_last_stride_ordered()
+                or weight.get_layout().is_channels_last_stride_ordered()
+            )
+        ):
+            output_layout_str = "torch.channels_last"

         if output_layout_str == "torch.channels_last":
             stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
```
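
The CPU condition mirrors eager-mode behavior: per the comment in the diff above, ATen's CPU convolution takes the channels_last path when either the activation or the weight is channels_last. A quick eager-mode check of that premise (illustrative, not repo code; the expected outputs assume the behavior the commit describes):

```python
import torch

conv = torch.nn.Conv2d(3, 3, 1, 1)
x_cl = torch.randn(2, 3, 16, 16).to(memory_format=torch.channels_last)

# channels_last activation, contiguous weight
print(conv(x_cl).is_contiguous(memory_format=torch.channels_last))  # expected True

# contiguous activation, channels_last weight
conv_cl = conv.to(memory_format=torch.channels_last)
out = conv_cl(torch.randn(2, 3, 16, 16))
print(out.is_contiguous(memory_format=torch.channels_last))  # expected True
```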
