Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1382,7 +1382,7 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len
else
strides = [];

return new Tensor<T>(tensor._values, tensor._start, lengths, strides);
return new Tensor<T>(tensor._values, tensor._start, newLengths, strides);
}

/// <summary>
Expand Down
120 changes: 120 additions & 0 deletions src/libraries/System.Numerics.Tensors/tests/ReadOnlyTensorSpanTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1437,5 +1437,125 @@ public static void ToStringZeroDataTest()
""";
Assert.Equal(expected, tensor.ToString([2, 0, 2]));
}

[Fact]
public static void TensorReshapeTest()
{
int[] a = [1, 2, 3, 4, 5, 6, 7, 8, 9];
nint[] dims = [9];
var origTensor = Tensor.CreateFromShapeUninitialized<int>(dims.AsSpan(), false);
var span = a.AsTensorSpan(dims);
span.CopyTo(origTensor);
var tensor = origTensor.AsReadOnlyTensorSpan();

Assert.Equal(1, tensor.Rank);
Assert.Equal(9, tensor.Lengths[0]);
Assert.Equal(1, tensor.Strides.Length);
Assert.Equal(1, tensor.Strides[0]);
Assert.Equal(1, tensor[0]);
Assert.Equal(2, tensor[1]);
Assert.Equal(3, tensor[2]);
Assert.Equal(4, tensor[3]);
Assert.Equal(5, tensor[4]);
Assert.Equal(6, tensor[5]);
Assert.Equal(7, tensor[6]);
Assert.Equal(8, tensor[7]);
Assert.Equal(9, tensor[8]);

dims = [3, 3];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(2, tensor.Rank);
Assert.Equal(3, tensor.Lengths[0]);
Assert.Equal(3, tensor.Lengths[1]);
Assert.Equal(2, tensor.Strides.Length);
Assert.Equal(3, tensor.Strides[0]);
Assert.Equal(1, tensor.Strides[1]);
Assert.Equal(1, tensor[0, 0]);
Assert.Equal(2, tensor[0, 1]);
Assert.Equal(3, tensor[0, 2]);
Assert.Equal(4, tensor[1, 0]);
Assert.Equal(5, tensor[1, 1]);
Assert.Equal(6, tensor[1, 2]);
Assert.Equal(7, tensor[2, 0]);
Assert.Equal(8, tensor[2, 1]);
Assert.Equal(9, tensor[2, 2]);

dims = [-1];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(1, tensor.Rank);
Assert.Equal(9, tensor.Lengths[0]);
Assert.Equal(1, tensor.Strides.Length);
Assert.Equal(1, tensor.Strides[0]);
Assert.Equal(1, tensor[0]);
Assert.Equal(2, tensor[1]);
Assert.Equal(3, tensor[2]);
Assert.Equal(4, tensor[3]);
Assert.Equal(5, tensor[4]);
Assert.Equal(6, tensor[5]);
Assert.Equal(7, tensor[6]);
Assert.Equal(8, tensor[7]);
Assert.Equal(9, tensor[8]);

dims = [3, -1];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(2, tensor.Rank);
Assert.Equal(3, tensor.Lengths[0]);
Assert.Equal(3, tensor.Lengths[1]);
Assert.Equal(2, tensor.Strides.Length);
Assert.Equal(3, tensor.Strides[0]);
Assert.Equal(1, tensor.Strides[1]);
Assert.Equal(1, tensor[0, 0]);
Assert.Equal(2, tensor[0, 1]);
Assert.Equal(3, tensor[0, 2]);
Assert.Equal(4, tensor[1, 0]);
Assert.Equal(5, tensor[1, 1]);
Assert.Equal(6, tensor[1, 2]);
Assert.Equal(7, tensor[2, 0]);
Assert.Equal(8, tensor[2, 1]);
Assert.Equal(9, tensor[2, 2]);

Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsReadOnlyTensorSpan(), [-1, -1]));

Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsReadOnlyTensorSpan(), [1, 2, 3, 4, 5]));

// Make sure reshape works correctly with 0 strides.
origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
tensor = origTensor.AsReadOnlyTensorSpan();
tensor = Tensor.Reshape(tensor, [1, 2]);
Assert.Equal(2, tensor.Rank);
Assert.Equal(1, tensor.Lengths[0]);
Assert.Equal(2, tensor.Lengths[1]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);

origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
tensor = origTensor.AsReadOnlyTensorSpan();
tensor = Tensor.Reshape(tensor, [2, 1]);
Assert.Equal(2, tensor.Rank);
Assert.Equal(2, tensor.Lengths[0]);
Assert.Equal(1, tensor.Lengths[1]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);

tensor = Tensor.Reshape(tensor, [1, 2, 1]);
Assert.Equal(3, tensor.Rank);
Assert.Equal(1, tensor.Lengths[0]);
Assert.Equal(2, tensor.Lengths[1]);
Assert.Equal(1, tensor.Lengths[2]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);
Assert.Equal(0, tensor.Strides[2]);

tensor = Tensor.Reshape(tensor, [1, 1, -1, 1]);
Assert.Equal(4, tensor.Rank);
Assert.Equal(1, tensor.Lengths[0]);
Assert.Equal(1, tensor.Lengths[1]);
Assert.Equal(2, tensor.Lengths[2]);
Assert.Equal(1, tensor.Lengths[3]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);
Assert.Equal(0, tensor.Strides[2]);
Assert.Equal(0, tensor.Strides[3]);
}
}
}
120 changes: 120 additions & 0 deletions src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2077,5 +2077,125 @@ public static void ToStringZeroDataTest()
""";
Assert.Equal(expected, tensor.ToString([2, 0, 2]));
}

[Fact]
public static void TensorReshapeTest()
{
int[] a = [1, 2, 3, 4, 5, 6, 7, 8, 9];
nint[] dims = [9];
var origTensor = Tensor.CreateFromShapeUninitialized<int>(dims.AsSpan(), false);
var span = a.AsTensorSpan(dims);
span.CopyTo(origTensor);
var tensor = origTensor.AsTensorSpan();

Assert.Equal(1, tensor.Rank);
Assert.Equal(9, tensor.Lengths[0]);
Assert.Equal(1, tensor.Strides.Length);
Assert.Equal(1, tensor.Strides[0]);
Assert.Equal(1, tensor[0]);
Assert.Equal(2, tensor[1]);
Assert.Equal(3, tensor[2]);
Assert.Equal(4, tensor[3]);
Assert.Equal(5, tensor[4]);
Assert.Equal(6, tensor[5]);
Assert.Equal(7, tensor[6]);
Assert.Equal(8, tensor[7]);
Assert.Equal(9, tensor[8]);

dims = [3, 3];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(2, tensor.Rank);
Assert.Equal(3, tensor.Lengths[0]);
Assert.Equal(3, tensor.Lengths[1]);
Assert.Equal(2, tensor.Strides.Length);
Assert.Equal(3, tensor.Strides[0]);
Assert.Equal(1, tensor.Strides[1]);
Assert.Equal(1, tensor[0, 0]);
Assert.Equal(2, tensor[0, 1]);
Assert.Equal(3, tensor[0, 2]);
Assert.Equal(4, tensor[1, 0]);
Assert.Equal(5, tensor[1, 1]);
Assert.Equal(6, tensor[1, 2]);
Assert.Equal(7, tensor[2, 0]);
Assert.Equal(8, tensor[2, 1]);
Assert.Equal(9, tensor[2, 2]);

dims = [-1];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(1, tensor.Rank);
Assert.Equal(9, tensor.Lengths[0]);
Assert.Equal(1, tensor.Strides.Length);
Assert.Equal(1, tensor.Strides[0]);
Assert.Equal(1, tensor[0]);
Assert.Equal(2, tensor[1]);
Assert.Equal(3, tensor[2]);
Assert.Equal(4, tensor[3]);
Assert.Equal(5, tensor[4]);
Assert.Equal(6, tensor[5]);
Assert.Equal(7, tensor[6]);
Assert.Equal(8, tensor[7]);
Assert.Equal(9, tensor[8]);

dims = [3, -1];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(2, tensor.Rank);
Assert.Equal(3, tensor.Lengths[0]);
Assert.Equal(3, tensor.Lengths[1]);
Assert.Equal(2, tensor.Strides.Length);
Assert.Equal(3, tensor.Strides[0]);
Assert.Equal(1, tensor.Strides[1]);
Assert.Equal(1, tensor[0, 0]);
Assert.Equal(2, tensor[0, 1]);
Assert.Equal(3, tensor[0, 2]);
Assert.Equal(4, tensor[1, 0]);
Assert.Equal(5, tensor[1, 1]);
Assert.Equal(6, tensor[1, 2]);
Assert.Equal(7, tensor[2, 0]);
Assert.Equal(8, tensor[2, 1]);
Assert.Equal(9, tensor[2, 2]);

Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsTensorSpan(), [-1, -1]));

Assert.Throws<ArgumentException>(() => Tensor.Reshape(origTensor.AsTensorSpan(), [1, 2, 3, 4, 5]));

// Make sure reshape works correctly with 0 strides.
origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
tensor = origTensor.AsTensorSpan();
tensor = Tensor.Reshape(tensor, [1, 2]);
Assert.Equal(2, tensor.Rank);
Assert.Equal(1, tensor.Lengths[0]);
Assert.Equal(2, tensor.Lengths[1]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);

origTensor = Tensor.CreateFromShape<int>((ReadOnlySpan<nint>)[2], [0], false);
tensor = origTensor.AsTensorSpan();
tensor = Tensor.Reshape(tensor, [2, 1]);
Assert.Equal(2, tensor.Rank);
Assert.Equal(2, tensor.Lengths[0]);
Assert.Equal(1, tensor.Lengths[1]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);

tensor = Tensor.Reshape(tensor, [1, 2, 1]);
Assert.Equal(3, tensor.Rank);
Assert.Equal(1, tensor.Lengths[0]);
Assert.Equal(2, tensor.Lengths[1]);
Assert.Equal(1, tensor.Lengths[2]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);
Assert.Equal(0, tensor.Strides[2]);

tensor = Tensor.Reshape(tensor, [1, 1, -1, 1]);
Assert.Equal(4, tensor.Rank);
Assert.Equal(1, tensor.Lengths[0]);
Assert.Equal(1, tensor.Lengths[1]);
Assert.Equal(2, tensor.Lengths[2]);
Assert.Equal(1, tensor.Lengths[3]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);
Assert.Equal(0, tensor.Strides[2]);
Assert.Equal(0, tensor.Strides[3]);
}
}
}
47 changes: 47 additions & 0 deletions src/libraries/System.Numerics.Tensors/tests/TensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2219,6 +2219,42 @@ public static void TensorReshapeTest()
Assert.Equal(8, tensor[2, 1]);
Assert.Equal(9, tensor[2, 2]);

dims = [-1];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(1, tensor.Rank);
Assert.Equal(9, tensor.Lengths[0]);
Assert.Equal(1, tensor.Strides.Length);
Assert.Equal(1, tensor.Strides[0]);
Assert.Equal(1, tensor[0]);
Assert.Equal(2, tensor[1]);
Assert.Equal(3, tensor[2]);
Assert.Equal(4, tensor[3]);
Assert.Equal(5, tensor[4]);
Assert.Equal(6, tensor[5]);
Assert.Equal(7, tensor[6]);
Assert.Equal(8, tensor[7]);
Assert.Equal(9, tensor[8]);

dims = [3, -1];
tensor = Tensor.Reshape(tensor, dims);
Assert.Equal(2, tensor.Rank);
Assert.Equal(3, tensor.Lengths[0]);
Assert.Equal(3, tensor.Lengths[1]);
Assert.Equal(2, tensor.Strides.Length);
Assert.Equal(3, tensor.Strides[0]);
Assert.Equal(1, tensor.Strides[1]);
Assert.Equal(1, tensor[0, 0]);
Assert.Equal(2, tensor[0, 1]);
Assert.Equal(3, tensor[0, 2]);
Assert.Equal(4, tensor[1, 0]);
Assert.Equal(5, tensor[1, 1]);
Assert.Equal(6, tensor[1, 2]);
Assert.Equal(7, tensor[2, 0]);
Assert.Equal(8, tensor[2, 1]);
Assert.Equal(9, tensor[2, 2]);

Assert.Throws<ArgumentException>(() => Tensor.Reshape(tensor, [-1, -1]));

Assert.Throws<ArgumentException>(() => Tensor.Reshape(tensor, [1, 2, 3, 4, 5]));

// Make sure reshape works correctly with 0 strides.
Expand Down Expand Up @@ -2246,6 +2282,17 @@ public static void TensorReshapeTest()
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);
Assert.Equal(0, tensor.Strides[2]);

tensor = Tensor.Reshape(tensor, [1, 1, -1, 1]);
Assert.Equal(4, tensor.Rank);
Assert.Equal(1, tensor.Lengths[0]);
Assert.Equal(1, tensor.Lengths[1]);
Assert.Equal(2, tensor.Lengths[2]);
Assert.Equal(1, tensor.Lengths[3]);
Assert.Equal(0, tensor.Strides[0]);
Assert.Equal(0, tensor.Strides[1]);
Assert.Equal(0, tensor.Strides[2]);
Assert.Equal(0, tensor.Strides[3]);
}

[Fact]
Expand Down
Loading