Skip to content

Commit

Permalink
Merge pull request dotnet/corefx#13942 from jamesqo/select-many
Browse files Browse the repository at this point in the history
Specialize the single-selector overload of SelectMany.

Commit migrated from dotnet/corefx@5bf69d1
  • Loading branch information
VSadov authored Dec 28, 2016
2 parents afddcaf + 98c06d7 commit ab25b8f
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 12 deletions.
137 changes: 125 additions & 12 deletions src/libraries/System.Linq/src/System/Linq/SelectMany.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;

namespace System.Linq
{
Expand All @@ -20,18 +21,7 @@ public static IEnumerable<TResult> SelectMany<TSource, TResult>(this IEnumerable
throw Error.ArgumentNull(nameof(selector));
}

return SelectManyIterator(source, selector);
}

private static IEnumerable<TResult> SelectManyIterator<TSource, TResult>(IEnumerable<TSource> source, Func<TSource, IEnumerable<TResult>> selector)
{
foreach (TSource element in source)
{
foreach (TResult subElement in selector(element))
{
yield return subElement;
}
}
return new SelectManySingleSelectorIterator<TSource, TResult>(source, selector);
}

public static IEnumerable<TResult> SelectMany<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, int, IEnumerable<TResult>> selector)
Expand Down Expand Up @@ -133,5 +123,128 @@ private static IEnumerable<TResult> SelectManyIterator<TSource, TCollection, TRe
}
}
}

/// <summary>
/// Iterator returned by the single-selector, non-indexed overload of SelectMany.
/// Each source element is projected into a sub-collection through the selector and
/// the elements of those sub-collections are yielded one after another.
/// </summary>
private sealed class SelectManySingleSelectorIterator<TSource, TResult> : Iterator<TResult>, IIListProvider<TResult>
{
    private readonly IEnumerable<TSource> _source;
    private readonly Func<TSource, IEnumerable<TResult>> _selector;

    // Live enumerators; both stay null until MoveNext opens them and are reset to null once disposed.
    private IEnumerator<TSource> _sourceEnumerator;
    private IEnumerator<TResult> _subEnumerator;

    /// <summary>Creates an iterator over <paramref name="source"/> flattened through <paramref name="selector"/>.</summary>
    internal SelectManySingleSelectorIterator(IEnumerable<TSource> source, Func<TSource, IEnumerable<TResult>> selector)
    {
        Debug.Assert(source != null);
        Debug.Assert(selector != null);

        _selector = selector;
        _source = source;
    }

    /// <summary>Returns a new iterator over the same source and selector; no enumerator state is copied.</summary>
    public override Iterator<TResult> Clone() =>
        new SelectManySingleSelectorIterator<TSource, TResult>(_source, _selector);

    /// <summary>Disposes the open sub-collection enumerator (if any), then the source enumerator (if any), then the base iterator.</summary>
    public override void Dispose()
    {
        // Inner enumerator first, mirroring the order in which nested loops would unwind.
        _subEnumerator?.Dispose();
        _subEnumerator = null;

        _sourceEnumerator?.Dispose();
        _sourceEnumerator = null;

        base.Dispose();
    }

    /// <summary>
    /// Sums the counts of every projected sub-collection.
    /// Returns -1 without enumerating anything when <paramref name="onlyIfCheap"/> is true.
    /// </summary>
    /// <exception cref="OverflowException">The running total exceeds <see cref="int.MaxValue"/>.</exception>
    public int GetCount(bool onlyIfCheap)
    {
        if (!onlyIfCheap)
        {
            int total = 0;

            foreach (TSource item in _source)
            {
                // The addition is checked so an oversized total throws instead of wrapping around.
                total = checked(total + _selector(item).Count());
            }

            return total;
        }

        return -1;
    }

    /// <summary>Advances to the next flattened element; disposes the iterator and returns false once the source is exhausted.</summary>
    public override bool MoveNext()
    {
        // State 1: first call on this iterator -> open the source enumerator.
        if (_state == 1)
        {
            _sourceEnumerator = _source.GetEnumerator();
            _state = 2;
        }

        // State 2: pull the next source element and open its sub-collection.
        // State 3: drain the sub-collection that is currently open.
        while (_state == 2 || _state == 3)
        {
            if (_state == 2)
            {
                if (!_sourceEnumerator.MoveNext())
                {
                    break;
                }

                TSource element = _sourceEnumerator.Current;
                _subEnumerator = _selector(element).GetEnumerator();
                _state = 3;
            }

            if (_subEnumerator.MoveNext())
            {
                _current = _subEnumerator.Current;
                return true;
            }

            // Sub-collection exhausted: release it right away and go back for the next source element.
            _subEnumerator.Dispose();
            _subEnumerator = null;
            _state = 2;
        }

        Dispose();
        return false;
    }

    /// <summary>Eagerly flattens the whole sequence into an array.</summary>
    public TResult[] ToArray()
    {
        LargeArrayBuilder<TResult> buffer = new LargeArrayBuilder<TResult>(initialize: true);

        foreach (TSource item in _source)
        {
            buffer.AddRange(_selector(item));
        }

        return buffer.ToArray();
    }

    /// <summary>Eagerly flattens the whole sequence into a list.</summary>
    public List<TResult> ToList()
    {
        List<TResult> results = new List<TResult>();

        foreach (TSource item in _source)
        {
            results.AddRange(_selector(item));
        }

        return results;
    }
}
}
}
24 changes: 24 additions & 0 deletions src/libraries/System.Linq/tests/EnumerableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,5 +285,29 @@ protected static List<Func<IEnumerable<T>, IEnumerable<T>>> IdentityTransforms<T
e => e.Where(i => true)
};
}

/// <summary>
/// Test double for <see cref="IEnumerator{T}"/>: every interface member simply invokes the
/// matching worker delegate. A worker must be assigned before its member is used; an
/// unassigned worker is null and invoking it throws.
/// </summary>
protected class DelegateBasedEnumerator<T> : IEnumerator<T>
{
    public Func<bool> MoveNextWorker { get; set; }
    public Func<T> CurrentWorker { get; set; }
    public Action DisposeWorker { get; set; }
    public Func<object> NonGenericCurrentWorker { get; set; }
    public Action ResetWorker { get; set; }

    /// <summary>Forwards to <see cref="CurrentWorker"/>.</summary>
    public T Current
    {
        get { return CurrentWorker(); }
    }

    /// <summary>Forwards to <see cref="NonGenericCurrentWorker"/>.</summary>
    object IEnumerator.Current
    {
        get { return NonGenericCurrentWorker(); }
    }

    /// <summary>Forwards to <see cref="MoveNextWorker"/>.</summary>
    public bool MoveNext()
    {
        return MoveNextWorker();
    }

    /// <summary>Forwards to <see cref="ResetWorker"/>.</summary>
    void IEnumerator.Reset()
    {
        ResetWorker();
    }

    /// <summary>Forwards to <see cref="DisposeWorker"/>.</summary>
    public void Dispose()
    {
        DisposeWorker();
    }
}

/// <summary>
/// Test double for <see cref="IEnumerable{T}"/>: both GetEnumerator overloads invoke the
/// matching worker delegate. A worker must be assigned before its overload is called.
/// </summary>
protected class DelegateBasedEnumerable<T> : IEnumerable<T>
{
    public Func<IEnumerator<T>> GetEnumeratorWorker { get; set; }
    public Func<IEnumerator> NonGenericGetEnumeratorWorker { get; set; }

    /// <summary>Forwards to <see cref="GetEnumeratorWorker"/>.</summary>
    public IEnumerator<T> GetEnumerator()
    {
        return GetEnumeratorWorker();
    }

    /// <summary>Forwards to <see cref="NonGenericGetEnumeratorWorker"/>.</summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return NonGenericGetEnumeratorWorker();
    }
}
}
}
116 changes: 116 additions & 0 deletions src/libraries/System.Linq/tests/SelectManyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -375,5 +375,121 @@ public void ForcedToEnumeratorDoesntEnumerateIndexedResultSel()
var en = iterator as IEnumerator<int>;
Assert.False(en != null && en.MoveNext());
}

[Theory]
[MemberData(nameof(ParameterizedTestsData))]
public void ParameterizedTests(IEnumerable<int> source, Func<int, IEnumerable<int>> selector)
{
    // Reference result: project every element by hand, then chain the projections together in order.
    IEnumerable<IEnumerable<int>> projections = source.Select(item => selector(item));
    IEnumerable<int> expected = projections.Aggregate((accumulated, next) => accumulated.Concat(next));

    IEnumerable<int> actual = source.SelectMany(selector);

    // Plain enumeration first, then each path for which SelectMany may provide a specialized implementation.
    Assert.Equal(expected, actual);
    Assert.Equal(expected.Count(), actual.Count()); // Count may be served by an optimized implementation.
    Assert.Equal(expected.ToArray(), actual.ToArray());
    Assert.Equal(expected.ToList(), actual.ToList());
}

/// <summary>
/// Produces 20 theory rows. Row <c>i</c> (1-based) contains the source <c>Enumerable.Range(1, i)</c>
/// and a selector mapping <c>n</c> to <c>Enumerable.Range(i, n)</c>.
/// </summary>
/// <returns>Rows of the form <c>{ IEnumerable&lt;int&gt; source, Func&lt;int, IEnumerable&lt;int&gt;&gt; selector }</c>.</returns>
public static IEnumerable<object[]> ParameterizedTestsData()
{
    for (int i = 1; i <= 20; i++)
    {
        // A 'for' loop variable is a single variable shared by all iterations, so a lambda
        // capturing 'i' directly would observe its final value (21) by the time the selector
        // is actually invoked. Copy it into a per-iteration local so each row keeps its own start.
        int start = i;
        Func<int, IEnumerable<int>> selector = n => Enumerable.Range(start, n);
        yield return new object[] { Enumerable.Range(1, start), selector };
    }
}

// Enumerates a SelectMany result built from instrumented enumerables and checks that each
// sub-collection enumerator is disposed once it has been exhausted, that the source enumerator
// is disposed only after the whole enumeration ends, and that the finished enumerator reports
// the default Current and keeps returning false from MoveNext.
[Theory]
[MemberData(nameof(DisposeAfterEnumerationData))]
public void DisposeAfterEnumeration(int sourceLength, int subLength)
{
int sourceState = 0;
int subIndex = 0; // Index within the arrays the sub-collection is supposed to be at.
int[] subState = new int[sourceLength];

bool sourceDisposed = false;
bool[] subCollectionDisposed = new bool[sourceLength];

// Source enumerator: MoveNext succeeds sourceLength times, every element is 0,
// and Dispose records that it was called.
var sourceEnumerator = new DelegateBasedEnumerator<int>
{
MoveNextWorker = () => ++sourceState <= sourceLength,
CurrentWorker = () => 0,
DisposeWorker = () => sourceDisposed = true
};

// Always hands back the single sourceEnumerator instance above.
var source = new DelegateBasedEnumerable<int>
{
GetEnumeratorWorker = () => sourceEnumerator
};

// One shared enumerator instance stands in for every sub-collection; subIndex selects
// which slot of subState / subCollectionDisposed it is currently reading and writing.
var subEnumerator = new DelegateBasedEnumerator<int>
{
// MoveNext: Return true subLength times.
// Dispose: Record that Dispose was called & move to the next index.
MoveNextWorker = () => ++subState[subIndex] <= subLength,
CurrentWorker = () => subState[subIndex],
DisposeWorker = () => subCollectionDisposed[subIndex++] = true
};

var subCollection = new DelegateBasedEnumerable<int>
{
GetEnumeratorWorker = () => subEnumerator
};

// Every source element projects to the same instrumented sub-collection.
var iterator = source.SelectMany(_ => subCollection);

int index = 0; // How much have we gone into the iterator?
IEnumerator<int> e = iterator.GetEnumerator();

// 'e' is declared outside the using statement so it can still be inspected after disposal.
using (e)
{
while (e.MoveNext())
{
int item = e.Current;

Assert.Equal(subState[subIndex], item); // Verify Current.
// Each sub-collection yields subLength items, so the number of items seen so far
// determines which sub-collection we must currently be in.
Assert.Equal(index / subLength, subIndex);

Assert.False(sourceDisposed); // Not yet.

// This represents whether the sub-collection we're iterating through right now
// has been disposed. Also not yet.
Assert.False(subCollectionDisposed[subIndex]);

// However, all of the sub-collections before us should have been disposed.
// Their indices should also be maxed out.
// (subLength + 1: subLength successful MoveNext calls plus the final one that returned false.)
Assert.All(subState.Take(subIndex), s => Assert.Equal(subLength + 1, s));
Assert.All(subCollectionDisposed.Take(subIndex), t => Assert.True(t));

index++;
}
}

// After the enumeration: the source and every sub-collection must have been disposed,
// and every sub-collection must have been run to exhaustion.
Assert.True(sourceDisposed);
Assert.Equal(sourceLength, subIndex);
Assert.All(subState, s => Assert.Equal(subLength + 1, s));
Assert.All(subCollectionDisposed, t => Assert.True(t));

// Make sure the iterator's enumerator has been disposed properly.
Assert.Equal(0, e.Current); // Default value.
Assert.False(e.MoveNext());
Assert.Equal(0, e.Current);
}

/// <summary>
/// Theory data for DisposeAfterEnumeration: every ordered pair
/// (source length, sub-collection length) drawn from a fixed list of lengths.
/// </summary>
public static IEnumerable<object[]> DisposeAfterEnumerationData()
{
    int[] lengths = { 1, 2, 3, 5, 8, 13, 21, 34 };

    // Cross the list with itself; the first 'from' varies slowest.
    return
        from sourceLength in lengths
        from subLength in lengths
        select new object[] { sourceLength, subLength };
}

// Each data row lists sub-collection lengths whose sum exceeds int.MaxValue,
// so counting the flattened sequence is expected to overflow.
[Theory]
[InlineData(new[] { int.MaxValue, 1 })]
[InlineData(new[] { 2, int.MaxValue - 1 })]
[InlineData(new[] { 123, 456, int.MaxValue - 100000, 123456 })]
public void ThrowOverflowExceptionOnConstituentLargeCounts(int[] counts)
{
    // Each entry becomes a range of that many elements; nothing is enumerated until Count() runs.
    IEnumerable<int> flattened = counts.SelectMany(length => Enumerable.Range(1, length));

    Assert.Throws<OverflowException>(() => flattened.Count());
}
}
}

0 comments on commit ab25b8f

Please sign in to comment.