Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,29 @@ public static Tensor<T> CreateFromShapeUninitialized<T>(scoped ReadOnlySpan<nint
/// <returns></returns>
public static ref readonly TensorSpan<T> FillGaussianNormalDistribution<T>(in TensorSpan<T> destination, Random? random = null) where T : IFloatingPoint<T>
{
Span<T> span = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
random ??= Random.Shared;

for (int i = 0; i < span.Length; i++)
if (destination.IsDense)
{
double u1 = 1.0 - random.NextDouble();
double u2 = 1.0 - random.NextDouble();
span[i] = T.CreateChecked(Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2));
Span<T> span = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination.FlattenedLength);

for (int i = 0; i < span.Length; i++)
{
double u1 = 1.0 - random.NextDouble();
double u2 = 1.0 - random.NextDouble();
span[i] = T.CreateChecked(Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2));
}
}
else
{
TensorSpan<T>.Enumerator enumerator = destination.GetEnumerator();

while (enumerator.MoveNext())
{
double u1 = 1.0 - random.NextDouble();
double u2 = 1.0 - random.NextDouble();
enumerator.Current = T.CreateChecked(Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2));
}
}

return ref destination;
Expand All @@ -370,10 +385,25 @@ public static ref readonly TensorSpan<T> FillGaussianNormalDistribution<T>(in Te
/// <returns></returns>
public static ref readonly TensorSpan<T> FillUniformDistribution<T>(in TensorSpan<T> destination, Random? random = null) where T : IFloatingPoint<T>
{
Span<T> span = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
random ??= Random.Shared;
for (int i = 0; i < span.Length; i++)
span[i] = T.CreateChecked(random.NextDouble());

if (destination.IsDense)
{
Span<T> span = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination.FlattenedLength);
for (int i = 0; i < span.Length; i++)
{
span[i] = T.CreateChecked(random.NextDouble());
}
}
else
{
TensorSpan<T>.Enumerator enumerator = destination.GetEnumerator();

while (enumerator.MoveNext())
{
enumerator.Current = T.CreateChecked(random.NextDouble());
}
}

return ref destination;
}
Expand Down Expand Up @@ -1549,12 +1579,33 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
nint newSize = TensorPrimitives.Product(lengths);
T[] values = tensor.IsPinned ? GC.AllocateArray<T>((int)newSize) : (new T[newSize]);
Tensor<T> output = Create(values, lengths, []);
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), tensor._values.Length - tensor._start);
Span<T> ospan = MemoryMarshal.CreateSpan(ref output.AsTensorSpan()._reference, (int)output.FlattenedLength);
if (newSize >= span.Length)
span.CopyTo(ospan);

if (tensor.IsDense)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), tensor._values.Length - tensor._start);
Span<T> ospan = MemoryMarshal.CreateSpan(ref output.AsTensorSpan()._reference, (int)output.FlattenedLength);
if (newSize >= span.Length)
{
span.CopyTo(ospan);
}
else
{
span.Slice(0, ospan.Length).CopyTo(ospan);
}
}
else
span.Slice(0, ospan.Length).CopyTo(ospan);
{
nint copyLength = Math.Min(tensor.FlattenedLength, newSize);
ReadOnlyTensorSpan<T>.Enumerator enumerator = tensor.AsReadOnlyTensorSpan().GetEnumerator();
Span<T> ospan = MemoryMarshal.CreateSpan(ref output.AsTensorSpan()._reference, (int)output.FlattenedLength);

for (nint i = 0; i < copyLength; i++)
{
bool moved = enumerator.MoveNext();
Debug.Assert(moved);
ospan[(int)i] = enumerator.Current;
}
}

return output;
}
Expand Down Expand Up @@ -1589,12 +1640,33 @@ public static void ResizeTo<T>(scoped in TensorSpan<T> tensor, in TensorSpan<T>
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
public static void ResizeTo<T>(scoped in ReadOnlyTensorSpan<T> tensor, in TensorSpan<T> destination)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
if (tensor.IsDense && destination.IsDense)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor.FlattenedLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination.FlattenedLength);
if (ospan.Length >= span.Length)
{
span.CopyTo(ospan);
}
else
{
span.Slice(0, ospan.Length).CopyTo(ospan);
}
}
else
span.Slice(0, ospan.Length).CopyTo(ospan);
{
nint copyLength = Math.Min(tensor.FlattenedLength, destination.FlattenedLength);
ReadOnlyTensorSpan<T>.Enumerator srcEnumerator = tensor.GetEnumerator();
TensorSpan<T>.Enumerator dstEnumerator = destination.GetEnumerator();

for (nint i = 0; i < copyLength; i++)
{
bool srcMoved = srcEnumerator.MoveNext();
bool dstMoved = dstEnumerator.MoveNext();
Debug.Assert(srcMoved && dstMoved);
dstEnumerator.Current = srcEnumerator.Current;
}
}
}
#endregion

Expand Down Expand Up @@ -1684,10 +1756,7 @@ public static ref readonly TensorSpan<T> ReverseDimension<T>(scoped in ReadOnlyT
public static bool SequenceEqual<T>(this scoped in TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<T> other)
where T : IEquatable<T>?
{
return tensor.FlattenedLength == other.FlattenedLength
&& tensor._shape.LinearLength == other._shape.LinearLength
&& tensor.Lengths.SequenceEqual(other.Lengths)
&& MemoryMarshal.CreateReadOnlySpan(in tensor.GetPinnableReference(), (int)tensor._shape.LinearLength).SequenceEqual(MemoryMarshal.CreateReadOnlySpan(in other.GetPinnableReference(), (int)other._shape.LinearLength));
return ((ReadOnlyTensorSpan<T>)tensor).SequenceEqual(other);
}

/// <summary>
Expand All @@ -1696,10 +1765,32 @@ public static bool SequenceEqual<T>(this scoped in TensorSpan<T> tensor, scoped
public static bool SequenceEqual<T>(this scoped in ReadOnlyTensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<T> other)
where T : IEquatable<T>?
{
return tensor.FlattenedLength == other.FlattenedLength
&& tensor._shape.LinearLength == other._shape.LinearLength
&& tensor.Lengths.SequenceEqual(other.Lengths)
&& MemoryMarshal.CreateReadOnlySpan(in tensor.GetPinnableReference(), (int)tensor._shape.LinearLength).SequenceEqual(MemoryMarshal.CreateReadOnlySpan(in other.GetPinnableReference(), (int)other._shape.LinearLength));
if (tensor.FlattenedLength != other.FlattenedLength
|| !tensor.Lengths.SequenceEqual(other.Lengths))
{
return false;
}

if (tensor.IsDense && other.IsDense)
{
return MemoryMarshal.CreateReadOnlySpan(in tensor.GetPinnableReference(), (int)tensor.FlattenedLength).SequenceEqual(MemoryMarshal.CreateReadOnlySpan(in other.GetPinnableReference(), (int)other.FlattenedLength));
}

ReadOnlyTensorSpan<T>.Enumerator enumerator1 = tensor.GetEnumerator();
ReadOnlyTensorSpan<T>.Enumerator enumerator2 = other.GetEnumerator();

while (enumerator1.MoveNext())
{
bool moved = enumerator2.MoveNext();
Debug.Assert(moved);

if (!EqualityComparer<T>.Default.Equals(enumerator1.Current, enumerator2.Current))
{
return false;
}
}

return true;
}
#endregion

Expand Down Expand Up @@ -3528,8 +3619,21 @@ public static ref readonly TensorSpan<int> ILogB<T>(scoped in ReadOnlyTensorSpan
public static nint IndexOfMax<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape.LinearLength);
return TensorPrimitives.IndexOfMax(span);
if (x.IsDense)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x.FlattenedLength);
return TensorPrimitives.IndexOfMax(span);
}

return IndexOfMaxFallback(x);
}

private static nint IndexOfMaxFallback<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
T[] flat = new T[x.FlattenedLength];
x.FlattenTo(flat);
return TensorPrimitives.IndexOfMax<T>(flat);
}

#endregion
Expand All @@ -3540,8 +3644,21 @@ public static nint IndexOfMax<T>(scoped in ReadOnlyTensorSpan<T> x)
public static nint IndexOfMaxMagnitude<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape.LinearLength);
return TensorPrimitives.IndexOfMaxMagnitude(span);
if (x.IsDense)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x.FlattenedLength);
return TensorPrimitives.IndexOfMaxMagnitude(span);
}

return IndexOfMaxMagnitudeFallback(x);
}

private static nint IndexOfMaxMagnitudeFallback<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
T[] flat = new T[x.FlattenedLength];
x.FlattenTo(flat);
return TensorPrimitives.IndexOfMaxMagnitude<T>(flat);
}
#endregion

Expand All @@ -3551,8 +3668,21 @@ public static nint IndexOfMaxMagnitude<T>(scoped in ReadOnlyTensorSpan<T> x)
public static nint IndexOfMin<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape.LinearLength);
return TensorPrimitives.IndexOfMin(span);
if (x.IsDense)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x.FlattenedLength);
return TensorPrimitives.IndexOfMin(span);
}

return IndexOfMinFallback(x);
}

private static nint IndexOfMinFallback<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
T[] flat = new T[x.FlattenedLength];
x.FlattenTo(flat);
return TensorPrimitives.IndexOfMin<T>(flat);
}
#endregion

Expand All @@ -3564,8 +3694,21 @@ public static nint IndexOfMin<T>(scoped in ReadOnlyTensorSpan<T> x)
public static nint IndexOfMinMagnitude<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape.LinearLength);
return TensorPrimitives.IndexOfMinMagnitude(span);
if (x.IsDense)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref x._reference, (int)x.FlattenedLength);
return TensorPrimitives.IndexOfMinMagnitude(span);
}

return IndexOfMinMagnitudeFallback(x);
}

private static nint IndexOfMinMagnitudeFallback<T>(scoped in ReadOnlyTensorSpan<T> x)
where T : INumber<T>
{
T[] flat = new T[x.FlattenedLength];
x.FlattenTo(flat);
return TensorPrimitives.IndexOfMinMagnitude<T>(flat);
}
#endregion

Expand Down
Loading
Loading