diff --git a/src/Nerdbank.Streams/MultiplexingStream.Options.cs b/src/Nerdbank.Streams/MultiplexingStream.Options.cs
index 2494a666..b24e522a 100644
--- a/src/Nerdbank.Streams/MultiplexingStream.Options.cs
+++ b/src/Nerdbank.Streams/MultiplexingStream.Options.cs
@@ -60,6 +60,11 @@ public class Options
///
private bool startSuspended;
+ ///
+ /// Backing field for the property.
+ ///
+ private bool faultOpenChannelsOnStreamDisposal;
+
///
/// Initializes a new instance of the class.
///
@@ -83,6 +88,7 @@ public Options(Options copyFrom)
this.defaultChannelTraceSourceFactory = copyFrom.defaultChannelTraceSourceFactory;
this.defaultChannelTraceSourceFactoryWithQualifier = copyFrom.defaultChannelTraceSourceFactoryWithQualifier;
this.startSuspended = copyFrom.startSuspended;
+ this.faultOpenChannelsOnStreamDisposal = copyFrom.faultOpenChannelsOnStreamDisposal;
this.SeededChannels = copyFrom.SeededChannels.ToList();
}
@@ -226,6 +232,20 @@ public bool StartSuspended
///
public IList SeededChannels { get; private set; }
+ ///
+ /// Gets or sets a value indicating whether any open channels should be faulted (i.e. their task will be faulted)
+ /// when the is disposed.
+ ///
+ public bool FaultOpenChannelsOnStreamDisposal
+ {
+ get => this.faultOpenChannelsOnStreamDisposal;
+ set
+ {
+ this.ThrowIfFrozen();
+ this.faultOpenChannelsOnStreamDisposal = value;
+ }
+ }
+
///
/// Gets a value indicating whether this instance is frozen.
///
diff --git a/src/Nerdbank.Streams/MultiplexingStream.cs b/src/Nerdbank.Streams/MultiplexingStream.cs
index 37b9473b..106549e9 100644
--- a/src/Nerdbank.Streams/MultiplexingStream.cs
+++ b/src/Nerdbank.Streams/MultiplexingStream.cs
@@ -104,6 +104,12 @@ public partial class MultiplexingStream : IDisposableObservable, System.IAsyncDi
///
private readonly int protocolMajorVersion;
+ ///
+ /// A value indicating whether any open channels should be faulted (i.e. their task will be faulted)
+ /// when the is disposed.
+ ///
+ private readonly bool faultOpenChannelsOnStreamDisposal;
+
///
/// The last number assigned to a channel.
/// Each use of this should increment by two, if has a value.
@@ -131,6 +137,7 @@ private MultiplexingStream(Formatter formatter, bool? isOdd, Options options)
}
this.TraceSource = options.TraceSource;
+ this.faultOpenChannelsOnStreamDisposal = options.FaultOpenChannelsOnStreamDisposal;
this.DefaultChannelTraceSourceFactory =
options.DefaultChannelTraceSourceFactoryWithQualifier
@@ -689,7 +696,7 @@ public async ValueTask DisposeAsync()
{
foreach (KeyValuePair entry in this.openChannels)
{
- entry.Value.Dispose(new ObjectDisposedException(nameof(MultiplexingStream)));
+ entry.Value.Dispose(this.faultOpenChannelsOnStreamDisposal ? new ObjectDisposedException(nameof(MultiplexingStream)) : null);
}
foreach (KeyValuePair>> entry in this.acceptingChannels)
diff --git a/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt b/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt
index b00dc10d..1343167c 100644
--- a/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt
+++ b/src/Nerdbank.Streams/netstandard2.0/PublicAPI.Unshipped.txt
@@ -1,4 +1,6 @@
Nerdbank.Streams.BufferWriterExtensions
+Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.get -> bool
+Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.set -> void
Nerdbank.Streams.ReadOnlySequenceExtensions
Nerdbank.Streams.StreamPipeReader
Nerdbank.Streams.StreamPipeReader.Read() -> System.IO.Pipelines.ReadResult
diff --git a/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt b/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt
index 31a266eb..e805ed90 100644
--- a/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt
+++ b/src/Nerdbank.Streams/netstandard2.1/PublicAPI.Unshipped.txt
@@ -1,4 +1,6 @@
Nerdbank.Streams.BufferWriterExtensions
+Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.get -> bool
+Nerdbank.Streams.MultiplexingStream.Options.FaultOpenChannelsOnStreamDisposal.set -> void
Nerdbank.Streams.ReadOnlySequenceExtensions
Nerdbank.Streams.StreamPipeReader
Nerdbank.Streams.StreamPipeReader.Read() -> System.IO.Pipelines.ReadResult
diff --git a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs
index 101fc06b..213850af 100644
--- a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs
+++ b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs
@@ -38,28 +38,7 @@ public MultiplexingStreamTests(ITestOutputHelper logger)
public async Task InitializeAsync()
{
- var mx1TraceSource = new TraceSource(nameof(this.mx1), SourceLevels.All);
- var mx2TraceSource = new TraceSource(nameof(this.mx2), SourceLevels.All);
-
- mx1TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer));
- mx2TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer));
-
- Func traceSourceFactory = (string mxInstanceName, MultiplexingStream.QualifiedChannelId id, string name) =>
- {
- var traceSource = new TraceSource(mxInstanceName + " channel " + id, SourceLevels.All);
- traceSource.Listeners.Clear(); // remove DefaultTraceListener
- traceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer));
- return traceSource;
- };
-
- Func mx1TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx1), id, name);
- Func mx2TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx2), id, name);
-
- (this.transport1, this.transport2) = FullDuplexStream.CreatePair(new PipeOptions(pauseWriterThreshold: 2 * 1024 * 1024));
- Task? mx1 = MultiplexingStream.CreateAsync(this.transport1, new MultiplexingStream.Options { ProtocolMajorVersion = this.ProtocolMajorVersion, TraceSource = mx1TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx1TraceSourceFactory }, this.TimeoutToken);
- Task? mx2 = MultiplexingStream.CreateAsync(this.transport2, new MultiplexingStream.Options { ProtocolMajorVersion = this.ProtocolMajorVersion, TraceSource = mx2TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx2TraceSourceFactory }, this.TimeoutToken);
- this.mx1 = await mx1;
- this.mx2 = await mx2;
+ await this.ReinitializeMxStreamsAsync(new MultiplexingStream.Options());
}
public async Task DisposeAsync()
@@ -302,13 +281,14 @@ public async Task Disposal_DisposesTransportStream()
Assert.Throws(() => this.transport1.Position);
}
- [Fact]
- public async Task Dispose_DisposesChannels()
+ [Theory, PairwiseData]
+ public async Task Dispose_DisposesChannels(bool channelFaulted)
{
+ await this.ReinitializeMxStreamsAsync(new MultiplexingStream.Options() { FaultOpenChannelsOnStreamDisposal = channelFaulted });
(MultiplexingStream.Channel channel1, MultiplexingStream.Channel channel2) = await this.EstablishChannelsAsync("A");
await this.mx1.DisposeAsync();
Assert.True(channel1.IsDisposed);
- await VerifyChannelCompleted(channel1, new ObjectDisposedException(nameof(MultiplexingStream)).Message);
+ await VerifyChannelCompleted(channel1, channelFaulted ? new ObjectDisposedException(nameof(MultiplexingStream)).Message : null);
#pragma warning disable CS0618 // Type or member is obsolete
await channel1.Input.WaitForWriterCompletionAsync().WithCancellation(this.TimeoutToken);
@@ -1322,6 +1302,40 @@ protected async Task WaitForEphemeralChannelOfferToPropagateAsync()
return (channel1.AsStream(), channel2.AsStream());
}
+ private async Task ReinitializeMxStreamsAsync(MultiplexingStream.Options optionsTemplate)
+ {
+ await (this.mx1?.DisposeAsync() ?? default);
+ await (this.mx2?.DisposeAsync() ?? default);
+
+ var mx1TraceSource = new TraceSource(nameof(this.mx1), SourceLevels.All);
+ var mx2TraceSource = new TraceSource(nameof(this.mx2), SourceLevels.All);
+
+ mx1TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer));
+ mx2TraceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer));
+
+ Func traceSourceFactory = (string mxInstanceName, MultiplexingStream.QualifiedChannelId id, string name) =>
+ {
+ var traceSource = new TraceSource(mxInstanceName + " channel " + id, SourceLevels.All);
+ traceSource.Listeners.Clear(); // remove DefaultTraceListener
+ traceSource.Listeners.Add(new XunitTraceListener(this.Logger, this.TestId, this.TestTimer));
+ return traceSource;
+ };
+
+ Func mx1TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx1), id, name);
+ Func mx2TraceSourceFactory = (MultiplexingStream.QualifiedChannelId id, string name) => traceSourceFactory(nameof(this.mx2), id, name);
+
+ optionsTemplate = new(optionsTemplate) { ProtocolMajorVersion = this.ProtocolMajorVersion };
+
+ var mx1Options = new MultiplexingStream.Options(optionsTemplate) { TraceSource = mx1TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx1TraceSourceFactory };
+ var mx2Options = new MultiplexingStream.Options(optionsTemplate) { TraceSource = mx2TraceSource, DefaultChannelTraceSourceFactoryWithQualifier = mx2TraceSourceFactory };
+
+ (this.transport1, this.transport2) = FullDuplexStream.CreatePair(new PipeOptions(pauseWriterThreshold: 2 * 1024 * 1024));
+ Task? mx1 = MultiplexingStream.CreateAsync(this.transport1, mx1Options, this.TimeoutToken);
+ Task? mx2 = MultiplexingStream.CreateAsync(this.transport2, mx2Options, this.TimeoutToken);
+ this.mx1 = await mx1;
+ this.mx2 = await mx2;
+ }
+
protected class SlowPipeWriter : PipeWriter
{
private readonly Sequence writtenBytes = new Sequence();