From 0ee242591098ed4e78635134e89e7da7c82780fe Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Tue, 20 Jun 2023 16:51:07 -0600 Subject: [PATCH] Revert backward breaking change from v2.9 The new feature added in #536 that communicates channel failures to the remote party also caused channels to report failure if their owner `MultiplexingStream` was disposed of before the channel was. This broke several tests in the vs-servicehub repo and could theoretically break shipping code as well. In this change I hide this particular behavioral change behind a setting that requires opt-in. --- .../MultiplexingStream.Options.cs | 20 ++++++ src/Nerdbank.Streams/MultiplexingStream.cs | 9 ++- .../netstandard2.0/PublicAPI.Unshipped.txt | 2 + .../netstandard2.1/PublicAPI.Unshipped.txt | 2 + .../MultiplexingStreamTests.cs | 64 +++++++++++-------- 5 files changed, 71 insertions(+), 26 deletions(-) diff --git a/src/Nerdbank.Streams/MultiplexingStream.Options.cs b/src/Nerdbank.Streams/MultiplexingStream.Options.cs index 2494a666..b24e522a 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.Options.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.Options.cs @@ -60,6 +60,11 @@ public class Options /// private bool startSuspended; + /// + /// Backing field for the property. + /// + private bool faultOpenChannelsOnStreamDisposal; + /// /// Initializes a new instance of the class. /// @@ -83,6 +88,7 @@ public Options(Options copyFrom) this.defaultChannelTraceSourceFactory = copyFrom.defaultChannelTraceSourceFactory; this.defaultChannelTraceSourceFactoryWithQualifier = copyFrom.defaultChannelTraceSourceFactoryWithQualifier; this.startSuspended = copyFrom.startSuspended; + this.faultOpenChannelsOnStreamDisposal = copyFrom.faultOpenChannelsOnStreamDisposal; this.SeededChannels = copyFrom.SeededChannels.ToList(); } @@ -226,6 +232,20 @@ public bool StartSuspended /// public IList SeededChannels { get; private set; } + /// + /// Gets or sets a value indicating whether any open channels should be faulted (i.e. their task will be faulted) + /// when the is disposed. + /// + public bool FaultOpenChannelsOnStreamDisposal + { + get => this.faultOpenChannelsOnStreamDisposal; + set + { + this.ThrowIfFrozen(); + this.faultOpenChannelsOnStreamDisposal = value; + } + } + /// /// Gets a value indicating whether this instance is frozen. /// diff --git a/src/Nerdbank.Streams/MultiplexingStream.cs b/src/Nerdbank.Streams/MultiplexingStream.cs index 37b9473b..106549e9 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.cs @@ -104,6 +104,12 @@ public partial class MultiplexingStream : IDisposableObservable, System.IAsyncDi /// private readonly int protocolMajorVersion; + /// + /// A value indicating whether any open channels should be faulted (i.e. their task will be faulted) + /// when the is disposed. + /// + private readonly bool faultOpenChannelsOnStreamDisposal; + /// /// The last number assigned to a channel. /// Each use of this should increment by two, if has a value. @@ -131,6 +137,7 @@ private MultiplexingStream(Formatter formatter, bool? isOdd, Options options) } this.TraceSource = options.TraceSource; + this.faultOpenChannelsOnStreamDisposal = options.FaultOpenChannelsOnStreamDisposal; this.DefaultChannelTraceSourceFactory = options.DefaultChannelTraceSourceFactoryWithQualifier @@ -689,7 +696,7 @@ public async ValueTask DisposeAsync() { foreach (KeyValuePair entry in this.openChannels) { - entry.Value.Dispose(new ObjectDisposedException(nameof(MultiplexingStream))); + entry.Value.Dispose(this.faultOpenChannelsOnStreamDisposal ? new ObjectDisposedException(nameof(MultiplexingStream)) : null); } foreach (KeyValuePair>> entry in this.acceptingChannels) diff --git a/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt b/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt index b00dc10d..1343167c 100644 --- a/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt @@ -1,4 +1,6 @@ Nerdbank.Streams.BufferWriterExtensions +Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.get -> bool +Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.set -> void Nerdbank.Streams.ReadOnlySequenceExtensions Nerdbank.Streams.StreamPipeReader Nerdbank.Streams.StreamPipeReader.Read() -> System.IO.Pipelines.ReadResult diff --git a/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt b/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt index 31a266eb..e805ed90 100644 --- a/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt +++ b/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt @@ -1,4 +1,6 @@ Nerdbank.Streams.BufferWriterExtensions +Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.get -> bool +Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.set -> void Nerdbank.Streams.ReadOnlySequenceExtensions Nerdbank.Streams.StreamPipeReader Nerdbank.Streams.StreamPipeReader.Read() -> System.IO.Pipelines.ReadResult diff --git a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs index 101fc06b..213850af 100644 --- a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs +++ b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs @@ -38,28 +38,7 @@ public MultiplexingStreamTests(ITestOutputHelper logger) public async Task InitializeAsync() { - var mx1TraceSource = new TraceSource(nameof(this.mx1), SourceLevels.All); - var mx2TraceSource = new TraceSource(nameof(this.mx2), SourceLevels.All); - - mx1TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer)); - mx2TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer)); - - Func traceSourceFactory = (string mxInstanceName, MultiplexingStream.QualifiedChannelId id, string name) => - { - var traceSource = new TraceSource(mxInstanceName + " channel " + id, SourceLevels.All); - traceSource.Listeners.Clear(); // remove DefaultTraceListener - traceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer)); - return traceSource; - }; - - Func mx1TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx1), id, name); - Func mx2TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx2), id, name); - - (this.transport1, this.transport2) = FullDuplexStream.CreatePair(new PipeOptions(pauseWriterThreshold: 2 * 1024 * 1024)); - Task? mx1 = MultiplexingStream.CreateAsync(this.transport1, new MultiplexingStream.Options { ProtocolMajorVersion = this.ProtocolMajorVersion, TraceSource = mx1TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx1TraceSourceFactory }, this.TimeoutToken); - Task? mx2 = MultiplexingStream.CreateAsync(this.transport2, new MultiplexingStream.Options { ProtocolMajorVersion = this.ProtocolMajorVersion, TraceSource = mx2TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx2TraceSourceFactory }, this.TimeoutToken); - this.mx1 = await mx1; - this.mx2 = await mx2; + await this.ReinitializeMxStreamsAsync(new MultiplexingStream.Options()); } public async Task DisposeAsync() @@ -302,13 +281,14 @@ public async Task Disposal_DisposesTransportStream() Assert.Throws(() => this.transport1.Position); } - [Fact] - public async Task Dispose_DisposesChannels() + [Theory, PairwiseData] + public async Task Dispose_DisposesChannels(bool channelFaulted) { + await this.ReinitializeMxStreamsAsync(new MultiplexingStream.Options() { FaultOpenChannelsOnStreamDisposal = channelFaulted }); (MultiplexingStream.Channel channel1, MultiplexingStream.Channel channel2) = await this.EstablishChannelsAsync("A"); await this.mx1.DisposeAsync(); Assert.True(channel1.IsDisposed); - await VerifyChannelCompleted(channel1, new ObjectDisposedException(nameof(MultiplexingStream)).Message); + await VerifyChannelCompleted(channel1, channelFaulted ? new ObjectDisposedException(nameof(MultiplexingStream)).Message : null); #pragma warning disable CS0618 // Type or member is obsolete await channel1.Input.WaitForWriterCompletionAsync().WithCancellation(this.TimeoutToken); @@ -1322,6 +1302,40 @@ protected async Task WaitForEphemeralChannelOfferToPropagateAsync() return (channel1.AsStream(), channel2.AsStream()); } + private async Task ReinitializeMxStreamsAsync(MultiplexingStream.Options optionsTemplate) + { + await (this.mx1?.DisposeAsync() ?? default); + await (this.mx2?.DisposeAsync() ?? default); + + var mx1TraceSource = new TraceSource(nameof(this.mx1), SourceLevels.All); + var mx2TraceSource = new TraceSource(nameof(this.mx2), SourceLevels.All); + + mx1TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer)); + mx2TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer)); + + Func traceSourceFactory = (string mxInstanceName, MultiplexingStream.QualifiedChannelId id, string name) => + { + var traceSource = new TraceSource(mxInstanceName + " channel " + id, SourceLevels.All); + traceSource.Listeners.Clear(); // remove DefaultTraceListener + traceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer)); + return traceSource; + }; + + Func mx1TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx1), id, name); + Func mx2TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx2), id, name); + + optionsTemplate = new(optionsTemplate) { ProtocolMajorVersion = this.ProtocolMajorVersion }; + + var mx1Options = new MultiplexingStream.Options(optionsTemplate) { TraceSource = mx1TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx1TraceSourceFactory }; + var mx2Options = new MultiplexingStream.Options(optionsTemplate) { TraceSource = mx2TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx2TraceSourceFactory }; + + (this.transport1, this.transport2) = FullDuplexStream.CreatePair(new PipeOptions(pauseWriterThreshold: 2 * 1024 * 1024)); + Task? mx1 = MultiplexingStream.CreateAsync(this.transport1, mx1Options, this.TimeoutToken); + Task? mx2 = MultiplexingStream.CreateAsync(this.transport2, mx2Options, this.TimeoutToken); + this.mx1 = await mx1; + this.mx2 = await mx2; + } + protected class SlowPipeWriter : PipeWriter { private readonly Sequence writtenBytes = new Sequence();