Skip to content

Commit f169b52

Browse files
authored
Fix Tensor.Reshape with wildcard (#120801)
Fixes #120343. Also added unit tests for `Tensor.Reshape` for `TensorSpan` and `ReadOnlyTensorSpan`
1 parent 3e7382e commit f169b52

File tree

4 files changed

+288
-1
lines changed

4 files changed

+288
-1
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,7 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
13821382
else
13831383
strides = [];
13841384

1385-
return new Tensor<T>(tensor._values, tensor._start, lengths, strides);
1385+
return new Tensor<T>(tensor._values, tensor._start, newLengths, strides);
13861386
}
13871387

13881388
/// <summary>

src/libraries/System.Numerics.Tensors/tests/ReadOnlyTensorSpanTests.cs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,5 +1437,125 @@ public static void ToStringZeroDataTest()
14371437
""";
14381438
Assert.Equal(expected, tensor.ToString([2, 0, 2]));
14391439
}
1440+
1441+
[Fact]
1442+
public static void TensorReshapeTest()
1443+
{
1444+
int[] a = [1, 2, 3, 4, 5, 6, 7, 8, 9];
1445+
nint[] dims = [9];
1446+
var origTensor = Tensor.CreateFromShapeUninitialized<int>(dims.AsSpan(), false);
1447+
var span = a.AsTensorSpan(dims);
1448+
span.CopyTo(origTensor);
1449+
var tensor = origTensor.AsReadOnlyTensorSpan();
1450+
1451+
Assert.Equal(1, tensor.Rank);
1452+
Assert.Equal(9, tensor.Lengths[0]);
1453+
Assert.Equal(1, tensor.Strides.Length);
1454+
Assert.Equal(1, tensor.Strides[0]);
1455+
Assert.Equal(1, tensor[0]);
1456+
Assert.Equal(2, tensor[1]);
1457+
Assert.Equal(3, tensor[2]);
1458+
Assert.Equal(4, tensor[3]);
1459+
Assert.Equal(5, tensor[4]);
1460+
Assert.Equal(6, tensor[5]);
1461+
Assert.Equal(7, tensor[6]);
1462+
Assert.Equal(8, tensor[7]);
1463+
Assert.Equal(9, tensor[8]);
1464+
1465+
dims = [3, 3];
1466+
tensor = Tensor.Reshape(tensor, dims);
1467+
Assert.Equal(2, tensor.Rank);
1468+
Assert.Equal(3, tensor.Lengths[0]);
1469+
Assert.Equal(3, tensor.Lengths[1]);
1470+
Assert.Equal(2, tensor.Strides.Length);
1471+
Assert.Equal(3, tensor.Strides[0]);
1472+
Assert.Equal(1, tensor.Strides[1]);
1473+
Assert.Equal(1, tensor[0, 0]);
1474+
Assert.Equal(2, tensor[0, 1]);
1475+
Assert.Equal(3, tensor[0, 2]);
1476+
Assert.Equal(4, tensor[1, 0]);
1477+
Assert.Equal(5, tensor[1, 1]);
1478+
Assert.Equal(6, tensor[1, 2]);
1479+
Assert.Equal(7, tensor[2, 0]);
1480+
Assert.Equal(8, tensor[2, 1]);
1481+
Assert.Equal(9, tensor[2, 2]);
1482+
1483+
dims = [-1];
1484+
tensor = Tensor.Reshape(tensor, dims);
1485+
Assert.Equal(1, tensor.Rank);
1486+
Assert.Equal(9, tensor.Lengths[0]);
1487+
Assert.Equal(1, tensor.Strides.Length);
1488+
Assert.Equal(1, tensor.Strides[0]);
1489+
Assert.Equal(1, tensor[0]);
1490+
Assert.Equal(2, tensor[1]);
1491+
Assert.Equal(3, tensor[2]);
1492+
Assert.Equal(4, tensor[3]);
1493+
Assert.Equal(5, tensor[4]);
1494+
Assert.Equal(6, tensor[5]);
1495+
Assert.Equal(7, tensor[6]);
1496+
Assert.Equal(8, tensor[7]);
1497+
Assert.Equal(9, tensor[8]);
1498+
1499+
dims = [3, -1];
1500+
tensor = Tensor.Reshape(tensor, dims);
1501+
Assert.Equal(2, tensor.Rank);
1502+
Assert.Equal(3, tensor.Lengths[0]);
1503+
Assert.Equal(3, tensor.Lengths[1]);
1504+
Assert.Equal(2, tensor.Strides.Length);
1505+
Assert.Equal(3, tensor.Strides[0]);
1506+
Assert.Equal(1, tensor.Strides[1]);
1507+
Assert.Equal(1, tensor[0, 0]);
1508+
Assert.Equal(2, tensor[0, 1]);
1509+
Assert.Equal(3, tensor[0, 2]);
1510+
Assert.Equal(4, tensor[1, 0]);
1511+
Assert.Equal(5, tensor[1, 1]);
1512+
Assert.Equal(6, tensor[1, 2]);
1513+
Assert.Equal(7, tensor[2, 0]);
1514+
Assert.Equal(8, tensor[2, 1]);
1515+
Assert.Equal(9, tensor[2, 2]);
1516+
1517+
Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsReadOnlyTensorSpan(), [-1, -1]));
1518+
1519+
Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsReadOnlyTensorSpan(), [1, 2, 3, 4, 5]));
1520+
1521+
// Make sure reshape works correctly with 0 strides.
1522+
origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
1523+
tensor = origTensor.AsReadOnlyTensorSpan();
1524+
tensor = Tensor.Reshape(tensor, [1, 2]);
1525+
Assert.Equal(2, tensor.Rank);
1526+
Assert.Equal(1, tensor.Lengths[0]);
1527+
Assert.Equal(2, tensor.Lengths[1]);
1528+
Assert.Equal(0, tensor.Strides[0]);
1529+
Assert.Equal(0, tensor.Strides[1]);
1530+
1531+
origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
1532+
tensor = origTensor.AsReadOnlyTensorSpan();
1533+
tensor = Tensor.Reshape(tensor, [2, 1]);
1534+
Assert.Equal(2, tensor.Rank);
1535+
Assert.Equal(2, tensor.Lengths[0]);
1536+
Assert.Equal(1, tensor.Lengths[1]);
1537+
Assert.Equal(0, tensor.Strides[0]);
1538+
Assert.Equal(0, tensor.Strides[1]);
1539+
1540+
tensor = Tensor.Reshape(tensor, [1, 2, 1]);
1541+
Assert.Equal(3, tensor.Rank);
1542+
Assert.Equal(1, tensor.Lengths[0]);
1543+
Assert.Equal(2, tensor.Lengths[1]);
1544+
Assert.Equal(1, tensor.Lengths[2]);
1545+
Assert.Equal(0, tensor.Strides[0]);
1546+
Assert.Equal(0, tensor.Strides[1]);
1547+
Assert.Equal(0, tensor.Strides[2]);
1548+
1549+
tensor = Tensor.Reshape(tensor, [1, 1, -1, 1]);
1550+
Assert.Equal(4, tensor.Rank);
1551+
Assert.Equal(1, tensor.Lengths[0]);
1552+
Assert.Equal(1, tensor.Lengths[1]);
1553+
Assert.Equal(2, tensor.Lengths[2]);
1554+
Assert.Equal(1, tensor.Lengths[3]);
1555+
Assert.Equal(0, tensor.Strides[0]);
1556+
Assert.Equal(0, tensor.Strides[1]);
1557+
Assert.Equal(0, tensor.Strides[2]);
1558+
Assert.Equal(0, tensor.Strides[3]);
1559+
}
14401560
}
14411561
}

src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,5 +2077,125 @@ public static void ToStringZeroDataTest()
20772077
""";
20782078
Assert.Equal(expected, tensor.ToString([2, 0, 2]));
20792079
}
2080+
2081+
[Fact]
2082+
public static void TensorReshapeTest()
2083+
{
2084+
int[] a = [1, 2, 3, 4, 5, 6, 7, 8, 9];
2085+
nint[] dims = [9];
2086+
var origTensor = Tensor.CreateFromShapeUninitialized<int>(dims.AsSpan(), false);
2087+
var span = a.AsTensorSpan(dims);
2088+
span.CopyTo(origTensor);
2089+
var tensor = origTensor.AsTensorSpan();
2090+
2091+
Assert.Equal(1, tensor.Rank);
2092+
Assert.Equal(9, tensor.Lengths[0]);
2093+
Assert.Equal(1, tensor.Strides.Length);
2094+
Assert.Equal(1, tensor.Strides[0]);
2095+
Assert.Equal(1, tensor[0]);
2096+
Assert.Equal(2, tensor[1]);
2097+
Assert.Equal(3, tensor[2]);
2098+
Assert.Equal(4, tensor[3]);
2099+
Assert.Equal(5, tensor[4]);
2100+
Assert.Equal(6, tensor[5]);
2101+
Assert.Equal(7, tensor[6]);
2102+
Assert.Equal(8, tensor[7]);
2103+
Assert.Equal(9, tensor[8]);
2104+
2105+
dims = [3, 3];
2106+
tensor = Tensor.Reshape(tensor, dims);
2107+
Assert.Equal(2, tensor.Rank);
2108+
Assert.Equal(3, tensor.Lengths[0]);
2109+
Assert.Equal(3, tensor.Lengths[1]);
2110+
Assert.Equal(2, tensor.Strides.Length);
2111+
Assert.Equal(3, tensor.Strides[0]);
2112+
Assert.Equal(1, tensor.Strides[1]);
2113+
Assert.Equal(1, tensor[0, 0]);
2114+
Assert.Equal(2, tensor[0, 1]);
2115+
Assert.Equal(3, tensor[0, 2]);
2116+
Assert.Equal(4, tensor[1, 0]);
2117+
Assert.Equal(5, tensor[1, 1]);
2118+
Assert.Equal(6, tensor[1, 2]);
2119+
Assert.Equal(7, tensor[2, 0]);
2120+
Assert.Equal(8, tensor[2, 1]);
2121+
Assert.Equal(9, tensor[2, 2]);
2122+
2123+
dims = [-1];
2124+
tensor = Tensor.Reshape(tensor, dims);
2125+
Assert.Equal(1, tensor.Rank);
2126+
Assert.Equal(9, tensor.Lengths[0]);
2127+
Assert.Equal(1, tensor.Strides.Length);
2128+
Assert.Equal(1, tensor.Strides[0]);
2129+
Assert.Equal(1, tensor[0]);
2130+
Assert.Equal(2, tensor[1]);
2131+
Assert.Equal(3, tensor[2]);
2132+
Assert.Equal(4, tensor[3]);
2133+
Assert.Equal(5, tensor[4]);
2134+
Assert.Equal(6, tensor[5]);
2135+
Assert.Equal(7, tensor[6]);
2136+
Assert.Equal(8, tensor[7]);
2137+
Assert.Equal(9, tensor[8]);
2138+
2139+
dims = [3, -1];
2140+
tensor = Tensor.Reshape(tensor, dims);
2141+
Assert.Equal(2, tensor.Rank);
2142+
Assert.Equal(3, tensor.Lengths[0]);
2143+
Assert.Equal(3, tensor.Lengths[1]);
2144+
Assert.Equal(2, tensor.Strides.Length);
2145+
Assert.Equal(3, tensor.Strides[0]);
2146+
Assert.Equal(1, tensor.Strides[1]);
2147+
Assert.Equal(1, tensor[0, 0]);
2148+
Assert.Equal(2, tensor[0, 1]);
2149+
Assert.Equal(3, tensor[0, 2]);
2150+
Assert.Equal(4, tensor[1, 0]);
2151+
Assert.Equal(5, tensor[1, 1]);
2152+
Assert.Equal(6, tensor[1, 2]);
2153+
Assert.Equal(7, tensor[2, 0]);
2154+
Assert.Equal(8, tensor[2, 1]);
2155+
Assert.Equal(9, tensor[2, 2]);
2156+
2157+
Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsTensorSpan(), [-1, -1]));
2158+
2159+
Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsTensorSpan(), [1, 2, 3, 4, 5]));
2160+
2161+
// Make sure reshape works correctly with 0 strides.
2162+
origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
2163+
tensor = origTensor.AsTensorSpan();
2164+
tensor = Tensor.Reshape(tensor, [1, 2]);
2165+
Assert.Equal(2, tensor.Rank);
2166+
Assert.Equal(1, tensor.Lengths[0]);
2167+
Assert.Equal(2, tensor.Lengths[1]);
2168+
Assert.Equal(0, tensor.Strides[0]);
2169+
Assert.Equal(0, tensor.Strides[1]);
2170+
2171+
origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
2172+
tensor = origTensor.AsTensorSpan();
2173+
tensor = Tensor.Reshape(tensor, [2, 1]);
2174+
Assert.Equal(2, tensor.Rank);
2175+
Assert.Equal(2, tensor.Lengths[0]);
2176+
Assert.Equal(1, tensor.Lengths[1]);
2177+
Assert.Equal(0, tensor.Strides[0]);
2178+
Assert.Equal(0, tensor.Strides[1]);
2179+
2180+
tensor = Tensor.Reshape(tensor, [1, 2, 1]);
2181+
Assert.Equal(3, tensor.Rank);
2182+
Assert.Equal(1, tensor.Lengths[0]);
2183+
Assert.Equal(2, tensor.Lengths[1]);
2184+
Assert.Equal(1, tensor.Lengths[2]);
2185+
Assert.Equal(0, tensor.Strides[0]);
2186+
Assert.Equal(0, tensor.Strides[1]);
2187+
Assert.Equal(0, tensor.Strides[2]);
2188+
2189+
tensor = Tensor.Reshape(tensor, [1, 1, -1, 1]);
2190+
Assert.Equal(4, tensor.Rank);
2191+
Assert.Equal(1, tensor.Lengths[0]);
2192+
Assert.Equal(1, tensor.Lengths[1]);
2193+
Assert.Equal(2, tensor.Lengths[2]);
2194+
Assert.Equal(1, tensor.Lengths[3]);
2195+
Assert.Equal(0, tensor.Strides[0]);
2196+
Assert.Equal(0, tensor.Strides[1]);
2197+
Assert.Equal(0, tensor.Strides[2]);
2198+
Assert.Equal(0, tensor.Strides[3]);
2199+
}
20802200
}
20812201
}

src/libraries/System.Numerics.Tensors/tests/TensorTests.cs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2219,6 +2219,42 @@ public static void TensorReshapeTest()
22192219
Assert.Equal(8, tensor[2, 1]);
22202220
Assert.Equal(9, tensor[2, 2]);
22212221

2222+
dims = [-1];
2223+
tensor = Tensor.Reshape(tensor, dims);
2224+
Assert.Equal(1, tensor.Rank);
2225+
Assert.Equal(9, tensor.Lengths[0]);
2226+
Assert.Equal(1, tensor.Strides.Length);
2227+
Assert.Equal(1, tensor.Strides[0]);
2228+
Assert.Equal(1, tensor[0]);
2229+
Assert.Equal(2, tensor[1]);
2230+
Assert.Equal(3, tensor[2]);
2231+
Assert.Equal(4, tensor[3]);
2232+
Assert.Equal(5, tensor[4]);
2233+
Assert.Equal(6, tensor[5]);
2234+
Assert.Equal(7, tensor[6]);
2235+
Assert.Equal(8, tensor[7]);
2236+
Assert.Equal(9, tensor[8]);
2237+
2238+
dims = [3, -1];
2239+
tensor = Tensor.Reshape(tensor, dims);
2240+
Assert.Equal(2, tensor.Rank);
2241+
Assert.Equal(3, tensor.Lengths[0]);
2242+
Assert.Equal(3, tensor.Lengths[1]);
2243+
Assert.Equal(2, tensor.Strides.Length);
2244+
Assert.Equal(3, tensor.Strides[0]);
2245+
Assert.Equal(1, tensor.Strides[1]);
2246+
Assert.Equal(1, tensor[0, 0]);
2247+
Assert.Equal(2, tensor[0, 1]);
2248+
Assert.Equal(3, tensor[0, 2]);
2249+
Assert.Equal(4, tensor[1, 0]);
2250+
Assert.Equal(5, tensor[1, 1]);
2251+
Assert.Equal(6, tensor[1, 2]);
2252+
Assert.Equal(7, tensor[2, 0]);
2253+
Assert.Equal(8, tensor[2, 1]);
2254+
Assert.Equal(9, tensor[2, 2]);
2255+
2256+
Assert.Throws<ArgumentException>(() => Tensor.Reshape(tensor, [-1, -1]));
2257+
22222258
Assert.Throws<ArgumentException>(() => Tensor.Reshape(tensor, [1, 2, 3, 4, 5]));
22232259

22242260
// Make sure reshape works correctly with 0 strides.
@@ -2246,6 +2282,17 @@ public static void TensorReshapeTest()
22462282
Assert.Equal(0, tensor.Strides[0]);
22472283
Assert.Equal(0, tensor.Strides[1]);
22482284
Assert.Equal(0, tensor.Strides[2]);
2285+
2286+
tensor = Tensor.Reshape(tensor, [1, 1, -1, 1]);
2287+
Assert.Equal(4, tensor.Rank);
2288+
Assert.Equal(1, tensor.Lengths[0]);
2289+
Assert.Equal(1, tensor.Lengths[1]);
2290+
Assert.Equal(2, tensor.Lengths[2]);
2291+
Assert.Equal(1, tensor.Lengths[3]);
2292+
Assert.Equal(0, tensor.Strides[0]);
2293+
Assert.Equal(0, tensor.Strides[1]);
2294+
Assert.Equal(0, tensor.Strides[2]);
2295+
Assert.Equal(0, tensor.Strides[3]);
22492296
}
22502297

22512298
[Fact]

0 commit comments

Comments
 (0)