diff --git a/Source/SuperLinq/FallbackIfEmpty.cs b/Source/SuperLinq/FallbackIfEmpty.cs index 5a7d93f3..ccf10731 100644 --- a/Source/SuperLinq/FallbackIfEmpty.cs +++ b/Source/SuperLinq/FallbackIfEmpty.cs @@ -40,7 +40,10 @@ public static IEnumerable FallbackIfEmpty(this IEnumerable source, IEnu Guard.IsNotNull(source); Guard.IsNotNull(fallback); - return Core(source, fallback); + return source.TryGetCollectionCount() is not null + && fallback.TryGetCollectionCount() is not null + ? new FallbackIfEmptyCollectionIterator(source, fallback) + : Core(source, fallback); static IEnumerable Core(IEnumerable source, IEnumerable fallback) { @@ -48,8 +51,11 @@ static IEnumerable Core(IEnumerable source, IEnumerable fallback) { if (e.MoveNext()) { - do { yield return e.Current; } - while (e.MoveNext()); + do + { + yield return e.Current; + } while (e.MoveNext()); + yield break; } } @@ -58,4 +64,28 @@ static IEnumerable Core(IEnumerable source, IEnumerable fallback) yield return item; } } + + private sealed class FallbackIfEmptyCollectionIterator : CollectionIterator + { + private readonly IEnumerable _source; + private readonly IEnumerable _fallback; + + public FallbackIfEmptyCollectionIterator(IEnumerable source, IEnumerable fallback) + { + _source = source; + _fallback = fallback; + } + + public override int Count => + _source.GetCollectionCount() == 0 + ? _fallback.Count() + : _source.GetCollectionCount(); + + protected override IEnumerable GetEnumerable() + { + return _source.GetCollectionCount() == 0 + ? _fallback + : _source; + } + } } diff --git a/Tests/SuperLinq.Test/FallbackIfEmptyTest.cs b/Tests/SuperLinq.Test/FallbackIfEmptyTest.cs index df6faba6..c3a4f5c6 100644 --- a/Tests/SuperLinq.Test/FallbackIfEmptyTest.cs +++ b/Tests/SuperLinq.Test/FallbackIfEmptyTest.cs @@ -5,16 +5,60 @@ public class FallbackIfEmptyTest [Fact] public void FallbackIfEmptyWithEmptySequence() { - using var source = Enumerable.Empty().AsTestingSequence(maxEnumerations: 2); + using var source = Seq().AsTestingSequence(); + source.FallbackIfEmpty(12).AssertSequenceEqual(12); + } + + [Fact] + public void FallbackIfEmptyWithCollectionSequence() + { + using var source = Seq().AsTestingCollection(); source.FallbackIfEmpty(12).AssertSequenceEqual(12); - source.FallbackIfEmpty(12, 23).AssertSequenceEqual(12, 23); } [Fact] public void FallbackIfEmptyWithNotEmptySequence() { - using var source = Seq(1).AsTestingSequence(maxEnumerations: 2); - source.FallbackIfEmpty(12).AssertSequenceEqual(1); - source.FallbackIfEmpty(12, 23).AssertSequenceEqual(1); + using var source = Seq(1).AsTestingSequence(); + source.FallbackIfEmpty(new BreakingSequence()).AssertSequenceEqual(1); + } + + [Fact] + public void FallbackIfEmptyWithNotEmptyCollectionSequence() + { + using var source = Seq(1).AsTestingCollection(); + source.FallbackIfEmpty(new BreakingSequence()).AssertSequenceEqual(1); + } + + [Fact] + public void AssertFallbackIfEmptyCollectionBehaviorOnEmptyCollection() + { + using var source = Seq().AsBreakingCollection(); + using var fallback = Enumerable.Range(0, 10_000).AsBreakingCollection(); + + var result = source.FallbackIfEmpty(fallback); + result.AssertCollectionErrorChecking(10_000); + } + + [Fact] + public void AssertFallbackIfEmptyCollectionBehaviorOnNonEmptyCollection() + { + using var source = Enumerable.Range(0, 10_000).AsBreakingCollection(); + using var fallback = new BreakingCollection(); + + var result = source.FallbackIfEmpty(fallback); + result.AssertCollectionErrorChecking(10_000); + } + + [Fact] + public void FallbackIfEmptyUsesCollectionCountAtIterationTime() + { + var stack = new Stack(); + + var result = stack.FallbackIfEmpty(4); + result.AssertSequenceEqual(4); + + stack.Push(1); + result.AssertSequenceEqual(1); } }