From a54d9e9b991c9059f85150f2e69b6507f103b322 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emmanuel=20Andr=C3=A9?= <2341261+manandre@users.noreply.github.com> Date: Tue, 25 Jun 2024 02:58:26 +0200 Subject: [PATCH] STJ: Dispose enumerator on exception (#100194) * STJ: Dispose enumerator on exception * Avoid code duplication * Rework fix * Remove useless Disposable field * Apply fix on all collection converters * Remove duplicate assignments * Skip fix for no-op Dispose implementation * Move IEnumerator disposal to WriteCore method. --------- Co-authored-by: Eirik Tsarpalis --- .../Collection/DictionaryDefaultConverter.cs | 3 +- .../Collection/IDictionaryConverter.cs | 3 +- .../Collection/IEnumerableConverter.cs | 3 +- .../Collection/IEnumerableDefaultConverter.cs | 3 +- .../Collection/StackOrQueueConverter.cs | 3 +- .../JsonConverterOfT.WriteCore.cs | 47 ++++++++++--------- .../Metadata/JsonTypeInfoOfT.WriteHelpers.cs | 1 + .../Text/Json/Serialization/WriteStack.cs | 6 +-- .../Text/Json/ThrowHelper.Serialization.cs | 6 +-- .../CollectionTests.Generic.Write.cs | 28 +++++++++++ .../Serialization/CollectionTests.cs | 2 + 11 files changed, 68 insertions(+), 37 deletions(-) diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/DictionaryDefaultConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/DictionaryDefaultConverter.cs index 27122579aeea39..659461a5e9fe69 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/DictionaryDefaultConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/DictionaryDefaultConverter.cs @@ -28,6 +28,7 @@ protected internal override bool OnWriteResume( if (state.Current.CollectionEnumerator == null) { enumerator = value.GetEnumerator(); + state.Current.CollectionEnumerator = enumerator; if (!enumerator.MoveNext()) { enumerator.Dispose(); @@ -47,7 +48,6 @@ protected internal override bool OnWriteResume( { if (ShouldFlush(ref state, writer)) { - state.Current.CollectionEnumerator = enumerator; return false; } @@ -61,7 +61,6 @@ protected internal override bool OnWriteResume( TValue element = enumerator.Current.Value; if (!_valueConverter.TryWrite(writer, element, options, ref state)) { - state.Current.CollectionEnumerator = enumerator; return false; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IDictionaryConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IDictionaryConverter.cs index 06cb6613d816fd..4ea31cc48ad34a 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IDictionaryConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IDictionaryConverter.cs @@ -45,6 +45,7 @@ protected internal override bool OnWriteResume(Utf8JsonWriter writer, TDictionar if (state.Current.CollectionEnumerator == null) { enumerator = value.GetEnumerator(); + state.Current.CollectionEnumerator = enumerator; if (!enumerator.MoveNext()) { return true; @@ -62,7 +63,6 @@ protected internal override bool OnWriteResume(Utf8JsonWriter writer, TDictionar { if (ShouldFlush(ref state, writer)) { - state.Current.CollectionEnumerator = enumerator; return false; } @@ -87,7 +87,6 @@ protected internal override bool OnWriteResume(Utf8JsonWriter writer, TDictionar object? element = enumerator.Value; if (!_valueConverter.TryWrite(writer, element, options, ref state)) { - state.Current.CollectionEnumerator = enumerator; return false; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableConverter.cs index c689ed197344dd..c156e58f812cd2 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableConverter.cs @@ -46,6 +46,7 @@ protected override bool OnWriteResume( if (state.Current.CollectionEnumerator == null) { enumerator = value.GetEnumerator(); + state.Current.CollectionEnumerator = enumerator; if (!enumerator.MoveNext()) { return true; @@ -61,14 +62,12 @@ protected override bool OnWriteResume( { if (ShouldFlush(ref state, writer)) { - state.Current.CollectionEnumerator = enumerator; return false; } object? element = enumerator.Current; if (!converter.TryWrite(writer, element, options, ref state)) { - state.Current.CollectionEnumerator = enumerator; return false; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableDefaultConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableDefaultConverter.cs index b32db7d002efca..e53e48151a25be 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableDefaultConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/IEnumerableDefaultConverter.cs @@ -22,6 +22,7 @@ protected override bool OnWriteResume(Utf8JsonWriter writer, TCollection value, if (state.Current.CollectionEnumerator == null) { enumerator = value.GetEnumerator(); + state.Current.CollectionEnumerator = enumerator; if (!enumerator.MoveNext()) { enumerator.Dispose(); @@ -39,14 +40,12 @@ protected override bool OnWriteResume(Utf8JsonWriter writer, TCollection value, { if (ShouldFlush(ref state, writer)) { - state.Current.CollectionEnumerator = enumerator; return false; } TElement element = enumerator.Current; if (!converter.TryWrite(writer, element, options, ref state)) { - state.Current.CollectionEnumerator = enumerator; return false; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/StackOrQueueConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/StackOrQueueConverter.cs index aeca4e87a23c3f..7a09048f8e450f 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/StackOrQueueConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Collection/StackOrQueueConverter.cs @@ -46,6 +46,7 @@ protected sealed override bool OnWriteResume(Utf8JsonWriter writer, TCollection if (state.Current.CollectionEnumerator == null) { enumerator = value.GetEnumerator(); + state.Current.CollectionEnumerator = enumerator; if (!enumerator.MoveNext()) { return true; @@ -61,14 +62,12 @@ protected sealed override bool OnWriteResume(Utf8JsonWriter writer, TCollection { if (ShouldFlush(ref state, writer)) { - state.Current.CollectionEnumerator = enumerator; return false; } object? element = enumerator.Current; if (!converter.TryWrite(writer, element, options, ref state)) { - state.Current.CollectionEnumerator = enumerator; return false; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonConverterOfT.WriteCore.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonConverterOfT.WriteCore.cs index 8f1d8be16f89e0..1a7636bdc52529 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonConverterOfT.WriteCore.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/JsonConverterOfT.WriteCore.cs @@ -15,32 +15,37 @@ internal bool WriteCore( { return TryWrite(writer, value, options, ref state); } - catch (InvalidOperationException ex) when (ex.Source == ThrowHelper.ExceptionSourceValueToRethrowAsJsonException) + catch (Exception ex) { - ThrowHelper.ReThrowWithPath(ref state, ex); - throw; - } - catch (JsonException ex) when (ex.Path == null) - { - // JsonExceptions where the Path property is already set - // typically originate from nested calls to JsonSerializer; - // treat these cases as any other exception type and do not - // overwrite any exception information. + if (!state.SupportAsync) + { + // Async serializers should dispose sync and + // async disposables from the async root method. + state.DisposePendingDisposablesOnException(); + } - ThrowHelper.AddJsonExceptionInformation(ref state, ex); - throw; - } - catch (NotSupportedException ex) - { - // If the message already contains Path, just re-throw. This could occur in serializer re-entry cases. - // To get proper Path semantics in re-entry cases, APIs that take 'state' need to be used. - if (ex.Message.Contains(" Path: ")) + switch (ex) { - throw; + case InvalidOperationException when ex.Source == ThrowHelper.ExceptionSourceValueToRethrowAsJsonException: + ThrowHelper.ReThrowWithPath(ref state, ex); + break; + + case JsonException { Path: null } jsonException: + // JsonExceptions where the Path property is already set + // typically originate from nested calls to JsonSerializer; + // treat these cases as any other exception type and do not + // overwrite any exception information. + ThrowHelper.AddJsonExceptionInformation(ref state, jsonException); + break; + + case NotSupportedException when !ex.Message.Contains(" Path: "): + // If the message already contains Path, just re-throw. This could occur in serializer re-entry cases. + // To get proper Path semantics in re-entry cases, APIs that take 'state' need to be used. + ThrowHelper.ThrowNotSupportedException(ref state, ex); + break; } - ThrowHelper.ThrowNotSupportedException(ref state, ex); - return default; + throw; } } } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfoOfT.WriteHelpers.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfoOfT.WriteHelpers.cs index 9cbe0e7ba8d805..c14eeb7ef78ed7 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfoOfT.WriteHelpers.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfoOfT.WriteHelpers.cs @@ -305,6 +305,7 @@ rootValue is not null && { ThrowHelper.ThrowInvalidOperationException_PipeWriterDoesNotImplementUnflushedBytes(bufferWriter); } + state.PipeWriter = bufferWriter; state.FlushThreshold = (int)(bufferWriter.Capacity * JsonSerializer.FlushThreshold); diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs index 38517ffab7f01e..30151d67ad8c72 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/WriteStack.cs @@ -278,7 +278,7 @@ public void AddCompletedAsyncDisposable(IAsyncDisposable asyncDisposable) => (CompletedAsyncDisposables ??= new List()).Add(asyncDisposable); // Asynchronously dispose of any AsyncDisposables that have been scheduled for disposal - public async ValueTask DisposeCompletedAsyncDisposables() + public readonly async ValueTask DisposeCompletedAsyncDisposables() { Debug.Assert(CompletedAsyncDisposables?.Count > 0); Exception? exception = null; @@ -307,7 +307,7 @@ public async ValueTask DisposeCompletedAsyncDisposables() /// Walks the stack cleaning up any leftover IDisposables /// in the event of an exception on serialization /// - public void DisposePendingDisposablesOnException() + public readonly void DisposePendingDisposablesOnException() { Exception? exception = null; @@ -346,7 +346,7 @@ static void DisposeFrame(IEnumerator? collectionEnumerator, ref Exception? excep /// Walks the stack cleaning up any leftover I(Async)Disposables /// in the event of an exception on async serialization /// - public async ValueTask DisposePendingDisposablesOnExceptionAsync() + public readonly async ValueTask DisposePendingDisposablesOnExceptionAsync() { Exception? exception = null; diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs b/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs index 4a2759a57c3983..226e60c80c2828 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs @@ -586,9 +586,9 @@ public static void ThrowNotSupportedException(scoped ref ReadStack state, in Utf } [DoesNotReturn] - public static void ThrowNotSupportedException(ref WriteStack state, NotSupportedException ex) + public static void ThrowNotSupportedException(ref WriteStack state, Exception innerException) { - string message = ex.Message; + string message = innerException.Message; // The caller should check to ensure path is not already set. Debug.Assert(!message.Contains(" Path: ")); @@ -608,7 +608,7 @@ public static void ThrowNotSupportedException(ref WriteStack state, NotSupported message += $" Path: {state.PropertyPath()}."; - throw new NotSupportedException(message, ex); + throw new NotSupportedException(message, innerException); } [DoesNotReturn] diff --git a/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.Generic.Write.cs b/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.Generic.Write.cs index a6131347de3cdf..19afcd962c750b 100644 --- a/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.Generic.Write.cs +++ b/src/libraries/System.Text.Json/tests/Common/CollectionTests/CollectionTests.Generic.Write.cs @@ -955,6 +955,34 @@ public async Task WriteISetT_DisposesEnumerators() } } + [Fact] + public async Task WriteIEnumerableT_ElementSerializationThrows_DisposesEnumerators() + { + var items = new RefCountedList>(Enumerable.Repeat(ThrowingEnumerable(), 1)); + await Assert.ThrowsAsync(() => Serializer.SerializeWrapper(items.AsEnumerable())); + Assert.Equal(0, items.RefCount); + + static IEnumerable ThrowingEnumerable() + { + yield return 42; + throw new DivideByZeroException(); + } + } + + [Fact] + public async Task WriteIDictionaryT_ElementSerializationThrows_DisposesEnumerators() + { + var items = new RefCountedDictionary>(Enumerable.Repeat(new KeyValuePair>(42, ThrowingEnumerable()), 1)); + await Assert.ThrowsAsync(() => Serializer.SerializeWrapper((IDictionary>)items)); + Assert.Equal(0, items.RefCount); + + static IEnumerable ThrowingEnumerable() + { + yield return 42; + throw new DivideByZeroException(); + } + } + public class SimpleClassWithKeyValuePairs { public KeyValuePair KvpWStrVal { get; set; } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs index 7463dfbfbb6bca..79528a4a61749d 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs @@ -130,6 +130,7 @@ public async Task DeserializeAsyncEnumerable() [JsonSerializable(typeof(JsonElement))] [JsonSerializable(typeof(string))] [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(IDictionary>))] [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(Dictionary>))] @@ -556,6 +557,7 @@ public CollectionTests_Default() [JsonSerializable(typeof(JsonElement))] [JsonSerializable(typeof(string))] [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(IDictionary>))] [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(Dictionary>))]