diff --git a/MoreLinq.Test/CartesianTest.cs b/MoreLinq.Test/CartesianTest.cs index 36dcaebcb..f46548808 100644 --- a/MoreLinq.Test/CartesianTest.cs +++ b/MoreLinq.Test/CartesianTest.cs @@ -154,10 +154,8 @@ public void TestCartesianProductCombinations() [Test] public void TestEmptyCartesianEvaluation() { - using var sequence = Enumerable.Range(0, 5).AsTestingSequence(); - - var resultA = sequence.Cartesian(Enumerable.Empty(), (a, b) => new { A = a, B = b }); - var resultB = Enumerable.Empty().Cartesian(sequence, (a, b) => new { A = a, B = b }); + var resultA = Enumerable.Range(0, 5).AsTestingSequence().Cartesian(Enumerable.Empty(), (a, b) => new { A = a, B = b }); + var resultB = Enumerable.Empty().Cartesian(Enumerable.Range(0, 5).AsTestingSequence(), (a, b) => new { A = a, B = b }); var resultC = Enumerable.Empty().Cartesian(Enumerable.Empty(), (a, b) => new { A = a, B = b }); Assert.AreEqual(0, resultA.Count()); diff --git a/MoreLinq.Test/MemoizeTest.cs b/MoreLinq.Test/MemoizeTest.cs index 80efcabf4..89f1a58fd 100644 --- a/MoreLinq.Test/MemoizeTest.cs +++ b/MoreLinq.Test/MemoizeTest.cs @@ -183,33 +183,7 @@ from list in lists } [Test] - public static void MemoizeRestartsAfterDisposal() - { - var starts = 0; - - IEnumerable TestSequence() - { - starts++; - yield return 1; - yield return 2; - } - - var memoized = TestSequence().Memoize(); - - void Run() - { - using ((IDisposable) memoized) - memoized.Take(1).Consume(); - } - - Run(); - Assert.That(starts, Is.EqualTo(1)); - Run(); - Assert.That(starts, Is.EqualTo(2)); - } - - [Test] - public static void MemoizeIteratorThrowsWhenCacheDisposedDuringIteration() + public static void MemoizeIteratorThrowsWhenDisposedDuringIteration() { var sequence = Enumerable.Range(1, 10); var memoized = sequence.Memoize(); @@ -221,7 +195,7 @@ public static void MemoizeIteratorThrowsWhenCacheDisposedDuringIteration() disposable.Dispose(); var e = Assert.Throws(() => reader.Read()); - Assert.That(e.ObjectName, Is.EqualTo("MemoizedEnumerable")); + Assert.That(e.ObjectName, Is.EqualTo("MemoizedEnumerable`1")); } [Test] @@ -259,67 +233,8 @@ IEnumerable TestSequence() var e2 = Assert.Throws(() => r2.Read()); Assert.That(e2, Is.SameAs(error)); } - using (var r1 = memoized.Read()) - Assert.That(r1.Read(), Is.EqualTo(123)); - } - - [Test] - public void MemoizeRethrowsErrorDuringIterationStartToAllIteratorsUntilDisposed() - { - var error = new Exception("This is a test exception."); - - var i = 0; - IEnumerable TestSequence() - { - if (0 == i++) // throw at start for first iteration only - throw error; - yield return 42; - } - - var xs = new DisposalTrackingSequence(TestSequence()); - var memoized = xs.Memoize(); - using ((IDisposable) memoized) - using (var r1 = memoized.Read()) - using (var r2 = memoized.Read()) - { - var e1 = Assert.Throws(() => r1.Read()); - Assert.That(e1, Is.SameAs(error)); - - Assert.That(xs.IsDisposed, Is.True); - - var e2 = Assert.Throws(() => r2.Read()); - Assert.That(e2, Is.SameAs(error)); - } - - using (var r1 = memoized.Read()) - using (var r2 = memoized.Read()) - Assert.That(r1.Read(), Is.EqualTo(r2.Read())); - } - - [Test] - public void MemoizeRethrowsErrorDuringFirstIterationStartToAllIterationsUntilDisposed() - { - var error = new Exception("An error on the first call!"); - var obj = new object(); - var calls = 0; - var source = Delegate.Enumerable(() => 0 == calls++ - ? throw error - : Enumerable.Repeat(obj, 1).GetEnumerator()); - - var memo = source.Memoize(); - - for (var i = 0; i < 2; i++) - { - var e = Assert.Throws(() => memo.First()); - Assert.That(e, Is.SameAs(error)); - } - - ((IDisposable) memo).Dispose(); - Assert.That(memo.Single(), Is.EqualTo(obj)); } - // TODO Consolidate with MoreLinq.Test.TestingSequence? - sealed class DisposalTrackingSequence : IEnumerable, IDisposable { readonly IEnumerable _sequence; diff --git a/MoreLinq.Test/TestingSequence.cs b/MoreLinq.Test/TestingSequence.cs index 9a6e3ed7f..2fd444724 100644 --- a/MoreLinq.Test/TestingSequence.cs +++ b/MoreLinq.Test/TestingSequence.cs @@ -69,14 +69,11 @@ public IEnumerator GetEnumerator() _disposed = false; enumerator.Disposed += delegate { - Assert.That(_disposed, Is.False, "LINQ operators should not dispose a sequence more than once."); _disposed = true; }; - var ended = false; enumerator.MoveNextCalled += (_, moved) => { - Assert.That(ended, Is.False, "LINQ operators should not continue iterating a sequence that has terminated."); - ended = !moved; + Assert.That(_disposed, Is.False, "LINQ operators should not call MoveNext() on a disposed sequence."); MoveNextCallCount++; }; _sequence = null; diff --git a/MoreLinq/Experimental/Memoize.cs b/MoreLinq/Experimental/Memoize.cs index 0888ae455..8d4d936d6 100644 --- a/MoreLinq/Experimental/Memoize.cs +++ b/MoreLinq/Experimental/Memoize.cs @@ -66,96 +66,70 @@ public static IEnumerable Memoize(this IEnumerable source) sealed class MemoizedEnumerable : IEnumerable, IDisposable { - List? _cache; + readonly List _cache = new(); readonly object _locker; - readonly IEnumerable _source; - IEnumerator? _sourceEnumerator; + + readonly Lazy> _sourceEnumerator; + bool _running = true; + bool _isDisposed = false; + int? _errorIndex; ExceptionDispatchInfo? _error; public MemoizedEnumerable(IEnumerable sequence) { - _source = sequence ?? throw new ArgumentNullException(nameof(sequence)); + if (sequence == null) throw new ArgumentNullException(nameof(sequence)); + _sourceEnumerator = new Lazy>( + sequence.GetEnumerator, + System.Threading.LazyThreadSafetyMode.ExecutionAndPublication); _locker = new object(); } public IEnumerator GetEnumerator() { - if (_cache == null) + var index = 0; + + while (true) { + if (_isDisposed) + throw new ObjectDisposedException(GetType().Name); + + T current; lock (_locker) { - if (_cache == null) + if (index >= _cache.Count) { - _error?.Throw(); + if (index == _errorIndex) + Assume.NotNull(_error).Throw(); + bool moved; try { - var cache = new List(); // for exception safety, allocate then... - _sourceEnumerator = _source.GetEnumerator(); // (because this can fail) - _cache = cache; // ...commit to state + moved = _running && _sourceEnumerator.Value.MoveNext(); } catch (Exception ex) { _error = ExceptionDispatchInfo.Capture(ex); + _errorIndex = index; + _sourceEnumerator.Value.Dispose(); throw; } - } - } - } - return _(); IEnumerator _() - { - var index = 0; - - while (true) - { - T current; - lock (_locker) - { - if (_cache == null) // Cache disposed during iteration? - throw new ObjectDisposedException(nameof(MemoizedEnumerable)); - - if (index >= _cache.Count) + if (!moved) { - if (index == _errorIndex) - Assume.NotNull(_error).Throw(); - - if (_sourceEnumerator == null) - break; - - bool moved; - try - { - moved = _sourceEnumerator.MoveNext(); - } - catch (Exception ex) - { - _error = ExceptionDispatchInfo.Capture(ex); - _errorIndex = index; - _sourceEnumerator.Dispose(); - _sourceEnumerator = null; - throw; - } - - if (moved) - { - _cache.Add(_sourceEnumerator.Current); - } - else - { - _sourceEnumerator.Dispose(); - _sourceEnumerator = null; - break; - } + _running = false; + _sourceEnumerator.Value.Dispose(); + yield break; } - current = _cache[index]; + _cache.Add(_sourceEnumerator.Value.Current); } - yield return current; - index++; + current = _cache[index]; } + + yield return current; + index++; } } @@ -166,10 +140,10 @@ public void Dispose() lock (_locker) { _error = null; - _cache = null; _errorIndex = null; - _sourceEnumerator?.Dispose(); - _sourceEnumerator = null; + _sourceEnumerator.Value.Dispose(); + + _isDisposed = true; } } }