From 51b308473903cae60f219828eedfac9dae130be7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stef=C3=A1n=20J=C3=B6kull=20Sigur=C3=B0arson?= Date: Mon, 11 Apr 2022 09:54:01 +0000 Subject: [PATCH 1/4] Adding pipelines --- projects/Benchmarks/ArrayBufferWriter.cs | 171 ++++++++++++ .../WireFormatting/MethodFraming.cs | 36 ++- .../RabbitMQ.Client/RabbitMQ.Client.csproj | 4 +- .../client/api/ConnectionFactoryBase.cs | 11 +- .../RabbitMQ.Client/client/framing/Model.cs | 216 +++++++-------- .../client/framing/Protocol.cs | 4 + .../client/impl/Connection.Commands.cs | 3 +- .../client/impl/Connection.Heartbeat.cs | 2 +- .../client/impl/Connection.Receive.cs | 138 ++++++---- .../RabbitMQ.Client/client/impl/Connection.cs | 97 ++++++- projects/RabbitMQ.Client/client/impl/Frame.cs | 159 +++-------- .../client/impl/IFrameHandler.cs | 11 +- .../client/impl/ProtocolBase.cs | 1 + .../client/impl/RecoveryAwareModel.cs | 1 - .../client/impl/SessionBase.cs | 4 +- .../client/impl/SocketFrameHandler.cs | 246 +++++------------- .../client/impl/TcpClientAdapterHelper.cs | 23 +- .../util/NetworkOrderDeserializer.cs | 35 +++ projects/Unit/APIApproval.cs | 1 - projects/Unit/Fixtures.cs | 9 +- projects/Unit/RabbitMQCtl.cs | 15 +- projects/Unit/TestAsyncConsumer.cs | 45 ++-- projects/Unit/TestBasicPublish.cs | 2 +- projects/Unit/TestFrameFormatting.cs | 33 +-- projects/Unit/TestPublisherConfirms.cs | 32 +-- 25 files changed, 733 insertions(+), 566 deletions(-) create mode 100644 projects/Benchmarks/ArrayBufferWriter.cs diff --git a/projects/Benchmarks/ArrayBufferWriter.cs b/projects/Benchmarks/ArrayBufferWriter.cs new file mode 100644 index 0000000000..e822af8fa4 --- /dev/null +++ b/projects/Benchmarks/ArrayBufferWriter.cs @@ -0,0 +1,171 @@ +// We only need this if we aren't targeting .NET 6.0 or greater since it already exists there +#if !NET6_0_OR_GREATER +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Buffers +{ + public class ArrayBufferWriter : IBufferWriter, IDisposable + { + private T[] _rentedBuffer; + private int _written; + private long _committed; + + private const int MinimumBufferSize = 256; + + public ArrayBufferWriter(int initialCapacity = MinimumBufferSize) + { + if (initialCapacity <= 0) + { + throw new ArgumentException(null, nameof(initialCapacity)); + } + + _rentedBuffer = ArrayPool.Shared.Rent(initialCapacity); + _written = 0; + _committed = 0; + } + + public Memory WrittenMemory + { + get + { + CheckIfDisposed(); + + return _rentedBuffer.AsMemory(0, _written); + } + } + + public Span WrittenSpan + { + get + { + CheckIfDisposed(); + + return _rentedBuffer.AsSpan(0, _written); + } + } + + public int BytesWritten + { + get + { + CheckIfDisposed(); + + return _written; + } + } + + public long BytesCommitted + { + get + { + CheckIfDisposed(); + + return _committed; + } + } + + public void Clear() + { + CheckIfDisposed(); + + ClearHelper(); + } + + private void ClearHelper() + { + _rentedBuffer.AsSpan(0, _written).Clear(); + _written = 0; + } + + public void Advance(int count) + { + CheckIfDisposed(); + + if (count < 0) + throw new ArgumentException(nameof(count)); + + if (_written > _rentedBuffer.Length - count) + throw new InvalidOperationException("Cannot advance past the end of the buffer."); + + _written += count; + } + + // Returns the rented buffer back to the pool + public void Dispose() + { + if (_rentedBuffer == null) + { + return; + } + + ArrayPool.Shared.Return(_rentedBuffer, clearArray: true); + _rentedBuffer = null; + _written = 0; + } + + private void CheckIfDisposed() + { + if (_rentedBuffer == null) + throw new ObjectDisposedException(nameof(ArrayBufferWriter)); + } + + public Memory GetMemory(int sizeHint = 0) + { + CheckIfDisposed(); + + if (sizeHint < 0) + throw new ArgumentException(nameof(sizeHint)); + + CheckAndResizeBuffer(sizeHint); + return _rentedBuffer.AsMemory(_written); + } + + public Span GetSpan(int sizeHint = 0) + { + CheckIfDisposed(); + + if (sizeHint < 0) + throw new ArgumentException(nameof(sizeHint)); + + CheckAndResizeBuffer(sizeHint); + return _rentedBuffer.AsSpan(_written); + } + + private void CheckAndResizeBuffer(int sizeHint) + { + Debug.Assert(sizeHint >= 0); + + if (sizeHint == 0) + { + sizeHint = MinimumBufferSize; + } + + int availableSpace = _rentedBuffer.Length - _written; + + if (sizeHint > availableSpace) + { + int growBy = sizeHint > _rentedBuffer.Length ? sizeHint : _rentedBuffer.Length; + + int newSize = checked(_rentedBuffer.Length + growBy); + + T[] oldBuffer = _rentedBuffer; + + _rentedBuffer = ArrayPool.Shared.Rent(newSize); + + Debug.Assert(oldBuffer.Length >= _written); + Debug.Assert(_rentedBuffer.Length >= _written); + + oldBuffer.AsSpan(0, _written).CopyTo(_rentedBuffer); + ArrayPool.Shared.Return(oldBuffer, clearArray: true); + } + + Debug.Assert(_rentedBuffer.Length - _written > 0); + Debug.Assert(_rentedBuffer.Length - _written >= sizeHint); + } + } +} +#endif diff --git a/projects/Benchmarks/WireFormatting/MethodFraming.cs b/projects/Benchmarks/WireFormatting/MethodFraming.cs index e2f032341e..b13569500a 100644 --- a/projects/Benchmarks/WireFormatting/MethodFraming.cs +++ b/projects/Benchmarks/WireFormatting/MethodFraming.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Text; using BenchmarkDotNet.Attributes; @@ -19,7 +20,12 @@ public class MethodFramingBasicAck public ushort Channel { get; set; } [Benchmark] - public ReadOnlyMemory BasicAckWrite() => Framing.SerializeToFrames(ref _basicAck, Channel); + public ReadOnlyMemory BasicAckWrite() + { + ArrayBufferWriter _writer = new ArrayBufferWriter(); + Framing.SerializeToFrames(ref _basicAck, _writer, Channel); + return _writer.WrittenMemory; + } } [Config(typeof(Config))] @@ -41,13 +47,28 @@ public class MethodFramingBasicPublish public int FrameMax { get; set; } [Benchmark] - public ReadOnlyMemory BasicPublishWriteNonEmpty() => Framing.SerializeToFrames(ref _basicPublish, ref _properties, _body, Channel, FrameMax); + public ReadOnlyMemory BasicPublishWriteNonEmpty() + { + ArrayBufferWriter _writer = new ArrayBufferWriter(); + Framing.SerializeToFrames(ref _basicPublish, ref _properties, _body, _writer, Channel, FrameMax); + return _writer.WrittenMemory; + } [Benchmark] - public ReadOnlyMemory BasicPublishWrite() => Framing.SerializeToFrames(ref _basicPublish, ref _propertiesEmpty, _bodyEmpty, Channel, FrameMax); + public ReadOnlyMemory BasicPublishWrite() + { + ArrayBufferWriter _writer = new ArrayBufferWriter(); + Framing.SerializeToFrames(ref _basicPublish, ref _propertiesEmpty, _bodyEmpty, _writer, Channel, FrameMax); + return _writer.WrittenMemory; + } [Benchmark] - public ReadOnlyMemory BasicPublishMemoryWrite() => Framing.SerializeToFrames(ref _basicPublishMemory, ref _propertiesEmpty, _bodyEmpty, Channel, FrameMax); + public ReadOnlyMemory BasicPublishMemoryWrite() + { + ArrayBufferWriter _writer = new ArrayBufferWriter(); + Framing.SerializeToFrames(ref _basicPublishMemory, ref _propertiesEmpty, _bodyEmpty, _writer, Channel, FrameMax); + return _writer.WrittenMemory; + } } [Config(typeof(Config))] @@ -60,6 +81,11 @@ public class MethodFramingChannelClose public ushort Channel { get; set; } [Benchmark] - public ReadOnlyMemory ChannelCloseWrite() => Framing.SerializeToFrames(ref _channelClose, Channel); + public ReadOnlyMemory ChannelCloseWrite() + { + ArrayBufferWriter _writer = new ArrayBufferWriter(); + Framing.SerializeToFrames(ref _channelClose, _writer, Channel); + return _writer.WrittenMemory; + } } } diff --git a/projects/RabbitMQ.Client/RabbitMQ.Client.csproj b/projects/RabbitMQ.Client/RabbitMQ.Client.csproj index 4365900ea6..838cc7631e 100644 --- a/projects/RabbitMQ.Client/RabbitMQ.Client.csproj +++ b/projects/RabbitMQ.Client/RabbitMQ.Client.csproj @@ -63,6 +63,6 @@ + - - + \ No newline at end of file diff --git a/projects/RabbitMQ.Client/client/api/ConnectionFactoryBase.cs b/projects/RabbitMQ.Client/client/api/ConnectionFactoryBase.cs index 4964b8f867..c224257b20 100644 --- a/projects/RabbitMQ.Client/client/api/ConnectionFactoryBase.cs +++ b/projects/RabbitMQ.Client/client/api/ConnectionFactoryBase.cs @@ -31,6 +31,9 @@ using System; using System.Net.Sockets; + +using Pipelines.Sockets.Unofficial; + using RabbitMQ.Client.Impl; namespace RabbitMQ.Client @@ -49,12 +52,8 @@ public class ConnectionFactoryBase /// New instance of a . public static ITcpClient DefaultSocketFactory(AddressFamily addressFamily) { - var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp) - { - NoDelay = true, - ReceiveBufferSize = 65536, - SendBufferSize = 65536 - }; + var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp); + SocketConnection.SetRecommendedClientOptions(socket); return new TcpClientAdapter(socket); } } diff --git a/projects/RabbitMQ.Client/client/framing/Model.cs b/projects/RabbitMQ.Client/client/framing/Model.cs index 964b4cc0c2..1d4de9020b 100644 --- a/projects/RabbitMQ.Client/client/framing/Model.cs +++ b/projects/RabbitMQ.Client/client/framing/Model.cs @@ -30,8 +30,8 @@ //--------------------------------------------------------------------------- using System.Collections.Generic; + using RabbitMQ.Client.client.framing; -using RabbitMQ.Client.client.impl; using RabbitMQ.Client.Impl; namespace RabbitMQ.Client.Framing.Impl @@ -297,113 +297,115 @@ public override void TxSelect() protected override bool DispatchAsynchronous(in IncomingCommand cmd) { - switch (cmd.CommandId) + int classType = (int)cmd.CommandId >> 16; + int methodType = (int)cmd.CommandId & 0x00FF; + return classType switch + { + ClassConstants.Basic => DispatchBasicCommand(in cmd, methodType), + ClassConstants.Channel => DispatchChannelCommand(in cmd, methodType), + ClassConstants.Connection => DispatchConnectionCommand(in cmd, methodType), + ClassConstants.Queue => DispatchQueueCommand(in cmd, methodType), + _ => false, + }; + } + + private bool DispatchQueueCommand(in IncomingCommand cmd, int methodType) + { + switch (methodType) + { + case QueueMethodConstants.DeclareOk: + HandleQueueDeclareOk(in cmd); + return true; + default: + return false; + } + } + + private bool DispatchConnectionCommand(in IncomingCommand cmd, int methodType) + { + switch (methodType) + { + case ConnectionMethodConstants.Start: + HandleConnectionStart(in cmd); + return true; + case ConnectionMethodConstants.Secure: + HandleConnectionSecure(in cmd); + return true; + case ConnectionMethodConstants.Tune: + HandleConnectionTune(in cmd); + return true; + case ConnectionMethodConstants.Close: + HandleConnectionClose(in cmd); + return true; + case ConnectionMethodConstants.Blocked: + HandleConnectionBlocked(in cmd); + return true; + case ConnectionMethodConstants.Unblocked: + cmd.ReturnMethodBuffer(); + HandleConnectionUnblocked(); + return true; + default: + return false; + } + } + + private bool DispatchChannelCommand(in IncomingCommand cmd, int methodType) + { + switch (methodType) + { + case ChannelMethodConstants.Flow: + HandleChannelFlow(in cmd); + return true; + case ChannelMethodConstants.Close: + HandleChannelClose(in cmd); + return true; + case ChannelMethodConstants.CloseOk: + cmd.ReturnMethodBuffer(); + HandleChannelCloseOk(); + return true; + default: + return false; + } + } + + private bool DispatchBasicCommand(in IncomingCommand cmd, int methodType) + { + switch (methodType) { - case ProtocolCommandId.BasicDeliver: - { - HandleBasicDeliver(in cmd); - return true; - } - case ProtocolCommandId.BasicAck: - { - HandleBasicAck(in cmd); - return true; - } - case ProtocolCommandId.BasicCancel: - { - HandleBasicCancel(in cmd); - return true; - } - case ProtocolCommandId.BasicCancelOk: - { - HandleBasicCancelOk(in cmd); - return true; - } - case ProtocolCommandId.BasicConsumeOk: - { - HandleBasicConsumeOk(in cmd); - return true; - } - case ProtocolCommandId.BasicGetEmpty: - { - cmd.ReturnMethodBuffer(); - HandleBasicGetEmpty(); - return true; - } - case ProtocolCommandId.BasicGetOk: - { - HandleBasicGetOk(in cmd); - return true; - } - case ProtocolCommandId.BasicNack: - { - HandleBasicNack(in cmd); - return true; - } - case ProtocolCommandId.BasicRecoverOk: - { - cmd.ReturnMethodBuffer(); - HandleBasicRecoverOk(); - return true; - } - case ProtocolCommandId.BasicReturn: - { - HandleBasicReturn(in cmd); - return true; - } - case ProtocolCommandId.ChannelClose: - { - HandleChannelClose(in cmd); - return true; - } - case ProtocolCommandId.ChannelCloseOk: - { - cmd.ReturnMethodBuffer(); - HandleChannelCloseOk(); - return true; - } - case ProtocolCommandId.ChannelFlow: - { - HandleChannelFlow(in cmd); - return true; - } - case ProtocolCommandId.ConnectionBlocked: - { - HandleConnectionBlocked(in cmd); - return true; - } - case ProtocolCommandId.ConnectionClose: - { - HandleConnectionClose(in cmd); - return true; - } - case ProtocolCommandId.ConnectionSecure: - { - HandleConnectionSecure(in cmd); - return true; - } - case ProtocolCommandId.ConnectionStart: - { - HandleConnectionStart(in cmd); - return true; - } - case ProtocolCommandId.ConnectionTune: - { - HandleConnectionTune(in cmd); - return true; - } - case ProtocolCommandId.ConnectionUnblocked: - { - cmd.ReturnMethodBuffer(); - HandleConnectionUnblocked(); - return true; - } - case ProtocolCommandId.QueueDeclareOk: - { - HandleQueueDeclareOk(in cmd); - return true; - } - default: return false; + case BasicMethodConstants.ConsumeOk: + HandleBasicConsumeOk(in cmd); + return true; + case BasicMethodConstants.Cancel: + HandleBasicCancel(in cmd); + return true; + case BasicMethodConstants.CancelOk: + HandleBasicCancelOk(in cmd); + return true; + case BasicMethodConstants.Return: + HandleBasicReturn(in cmd); + return true; + case BasicMethodConstants.Deliver: + HandleBasicDeliver(in cmd); + return true; + case BasicMethodConstants.GetOk: + HandleBasicGetOk(in cmd); + return true; + case BasicMethodConstants.GetEmpty: + cmd.ReturnMethodBuffer(); + HandleBasicGetEmpty(); + return true; + case BasicMethodConstants.Ack: + HandleBasicAck(in cmd); + return true; + case BasicMethodConstants.RecoverOk: + cmd.ReturnMethodBuffer(); + HandleBasicRecoverOk(); + return true; + case BasicMethodConstants.Nack: + HandleBasicNack(in cmd); + return true; + default: + return false; } } } diff --git a/projects/RabbitMQ.Client/client/framing/Protocol.cs b/projects/RabbitMQ.Client/client/framing/Protocol.cs index 9f4a8d900e..d96ee5cbb6 100644 --- a/projects/RabbitMQ.Client/client/framing/Protocol.cs +++ b/projects/RabbitMQ.Client/client/framing/Protocol.cs @@ -50,6 +50,10 @@ internal sealed class Protocol : ProtocolBase ///Protocol API name (= :AMQP_0_9_1) public override string ApiName => ":AMQP_0_9_1"; + public override ReadOnlySpan Header => Amqp091Header; + + private static ReadOnlySpan Amqp091Header => new byte[] { (byte)'A', (byte)'M', (byte)'Q', (byte)'P', 0, 0, 9, 1 }; + ///Default TCP port (= 5672) public override int DefaultPort => 5672; diff --git a/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs b/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs index 7c53fe579b..7158231e31 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs @@ -32,6 +32,7 @@ using System; using System.IO; using System.Text; + using RabbitMQ.Client.Events; using RabbitMQ.Client.Exceptions; using RabbitMQ.Client.Impl; @@ -83,7 +84,7 @@ private void StartAndTune() _model0.m_connectionStartCell = connectionStartCell; _model0.HandshakeContinuationTimeout = _config.HandshakeContinuationTimeout; _frameHandler.ReadTimeout = _config.HandshakeContinuationTimeout; - _frameHandler.SendHeader(); + Write(Protocol.Header); ConnectionStartDetails connectionStart = connectionStartCell.WaitForValue(); diff --git a/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs b/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs index bd3ae7d7dc..a374506ad2 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs @@ -149,7 +149,7 @@ private void HeartbeatWriteTimerCallback(object? state) { if (!_closed) { - Write(Client.Impl.Framing.Heartbeat.GetHeartbeatFrame()); + Write(Client.Impl.Framing.Heartbeat.Payload); _heartbeatWriteTimer?.Change((int)_heartbeatTimeSpan.TotalMilliseconds, Timeout.Infinite); } } diff --git a/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs b/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs index 38274cea51..849931dce1 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs @@ -30,7 +30,10 @@ //--------------------------------------------------------------------------- using System; +using System.Buffers; +using System.Diagnostics; using System.IO; +using System.IO.Pipelines; using System.Threading.Tasks; using RabbitMQ.Client.Exceptions; using RabbitMQ.Client.Impl; @@ -43,11 +46,11 @@ internal sealed partial class Connection private readonly IFrameHandler _frameHandler; private readonly Task _mainLoopTask; - private void MainLoop() + private async Task MainLoop() { try { - ReceiveLoop(); + await MainLoopIteration().ConfigureAwait(false); } catch (EndOfStreamException eose) { @@ -56,7 +59,7 @@ private void MainLoop() } catch (HardProtocolException hpe) { - HardProtocolExceptionHandler(hpe); + await HardProtocolExceptionHandler(hpe); } catch (Exception ex) { @@ -66,57 +69,104 @@ private void MainLoop() FinishClose(); } - private void ReceiveLoop() + private async ValueTask MainLoopIteration() { - while (!_closed) + try { - InboundFrame frame = _frameHandler.ReadFrame(); - NotifyHeartbeatListener(); - - bool shallReturn = true; - if (frame.Channel == 0) + while (!_closed) { - if (frame.Type == FrameType.FrameHeartbeat) + Debug.WriteLine("Trying to read synchronously from pipe."); + if (!_frameHandler.FrameReader.TryRead(out ReadResult result)) { - // Ignore it: we've already just reset the heartbeat + Debug.WriteLine("Failed to read synchronously from pipe, going async..."); + result = await _frameHandler.FrameReader.ReadAsync().ConfigureAwait(false); } - else + + ReadOnlySequence buffer = result.Buffer; + Debug.WriteLine("Read {0:N0} bytes from pipe.", result.Buffer.Length); + + try { - // In theory, we could get non-connection.close-ok - // frames here while we're quiescing (m_closeReason != - // null). In practice, there's a limited number of - // things the server can ask of us on channel 0 - - // essentially, just connection.close. That, combined - // with the restrictions on pipelining, mean that - // we're OK here to handle channel 0 traffic in a - // quiescing situation, even though technically we - // should be ignoring everything except - // connection.close-ok. - shallReturn = _session0.HandleFrame(in frame); + // If we canceled or we are empty + if (buffer.IsEmpty) + { + throw new EndOfStreamException("Reached the end of the stream. Possible authentication failure."); + } + + if (buffer.First.Span[0] == 'A') + { + if (buffer.Length >= 8) + { + InboundFrame.ProcessProtocolHeader(buffer); + } + + throw new EndOfStreamException("Invalid/truncated protocol header."); + } + + int framesRead = 0; + while (7 < (uint)buffer.Length && InboundFrame.TryParseInboundFrame(ref buffer, out InboundFrame frame)) + { + framesRead++; + HandleBytes(in frame); + } + + Debug.WriteLine("Read {0:N0} frames from pipe. Remaining buffer size is {0:N0}.", framesRead, buffer.Length); } - } - else - { - // If we're still m_running, but have a m_closeReason, - // then we must be quiescing, which means any inbound - // frames for non-zero channels (and any inbound - // commands on channel zero that aren't - // Connection.CloseOk) must be discarded. - if (_closeReason is null) + finally + { + _frameHandler.FrameReader.AdvanceTo(buffer.Start, buffer.End); + } + + // We won't be receiving more data + if (result.IsCompleted) { - // No close reason, not quiescing the - // connection. Handle the frame. (Of course, the - // Session itself may be quiescing this particular - // channel, but that's none of our concern.) - shallReturn = _sessionManager.Lookup(frame.Channel).HandleFrame(in frame); + throw new EndOfStreamException("Reached the end of the stream. Possible authentication failure."); } } + } + finally + { + await _frameHandler.FrameReader.CompleteAsync(); + } + } - if (shallReturn) + private void HandleBytes(in InboundFrame frame) + { + NotifyHeartbeatListener(); + + // Nothing to do if this is a heartbeat. + if (frame.Type != FrameType.FrameHeartbeat) + { + // In theory, we could get non-connection.close-ok + // frames here while we're quiescing (m_closeReason != + // null). In practice, there's a limited number of + // things the server can ask of us on channel 0 - + // essentially, just connection.close. That, combined + // with the restrictions on pipelining, mean that + // we're OK here to handle channel 0 traffic in a + // quiescing situation, even though technically we + // should be ignoring everything except + // connection.close-ok. + switch (frame.Channel) { - frame.ReturnPayload(); + case 0: + if (!_session0.HandleFrame(in frame)) + { + return; + } + break; + default: + if (_closeReason is null && !_sessionManager.Lookup(frame.Channel).HandleFrame(in frame)) + { + return; + } + + break; } } + + frame.ReturnPayload(); + return; } /// @@ -139,7 +189,7 @@ private void HandleMainLoopException(ShutdownEventArgs reason) LogCloseError($"Unexpected connection closure: {reason}", new Exception(reason.ToString())); } - private void HardProtocolExceptionHandler(HardProtocolException hpe) + private async Task HardProtocolExceptionHandler(HardProtocolException hpe) { if (SetCloseReason(hpe.ShutdownReason)) { @@ -149,7 +199,7 @@ private void HardProtocolExceptionHandler(HardProtocolException hpe) { var cmd = new ConnectionClose(hpe.ShutdownReason.ReplyCode, hpe.ShutdownReason.ReplyText, 0, 0); _session0.Transmit(ref cmd); - ClosingLoop(); + await ClosingLoop(); } catch (IOException ioe) { @@ -165,13 +215,13 @@ private void HardProtocolExceptionHandler(HardProtocolException hpe) /// /// Loop only used while quiescing. Use only to cleanly close connection /// - private void ClosingLoop() + private async Task ClosingLoop() { try { _frameHandler.ReadTimeout = TimeSpan.Zero; // Wait for response/socket closure or timeout - ReceiveLoop(); + await MainLoopIteration().ConfigureAwait(false); } catch (ObjectDisposedException ode) { diff --git a/projects/RabbitMQ.Client/client/impl/Connection.cs b/projects/RabbitMQ.Client/client/impl/Connection.cs index f7036728d8..504e448474 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.cs @@ -32,6 +32,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -58,6 +59,8 @@ internal sealed partial class Connection : IConnection private ShutdownEventArgs? _closeReason; public ShutdownEventArgs? CloseReason => Volatile.Read(ref _closeReason); + private readonly SemaphoreSlim _writeSemaphore = new SemaphoreSlim(1); + public Connection(ConnectionConfig config, IFrameHandler frameHandler) { _config = config; @@ -71,7 +74,7 @@ public Connection(ConnectionConfig config, IFrameHandler frameHandler) _sessionManager = new SessionManager(this, 0); _session0 = new MainSession(this); - _model0 = new Model(_config, _session0); ; + _model0 = new Model(_config, _session0); ClientProperties = new Dictionary(_config.ClientProperties) { @@ -79,7 +82,7 @@ public Connection(ConnectionConfig config, IFrameHandler frameHandler) ["connection_name"] = ClientProvidedName }; - _mainLoopTask = Task.Factory.StartNew(MainLoop, TaskCreationOptions.LongRunning); + _mainLoopTask = Task.Run(MainLoop); try { Open(); @@ -373,7 +376,7 @@ private void OnShutdown(ShutdownEventArgs reason) private bool SetCloseReason(ShutdownEventArgs reason) { - return System.Threading.Interlocked.CompareExchange(ref _closeReason, reason, null) is null; + return Interlocked.CompareExchange(ref _closeReason, reason, null) is null; } private void LogCloseError(string error, Exception ex) @@ -394,9 +397,93 @@ internal void OnCallbackException(CallbackExceptionEventArgs args) _callbackExceptionWrapper.Invoke(this, args); } - internal void Write(ReadOnlyMemory memory) + internal void Write(ReadOnlySpan payload) + { + if (!_frameHandler.IsClosed) + { + _writeSemaphore.Wait(); + WriteSpan(payload); + _frameHandler.FrameWriter.FlushAsync().AsTask().Wait(); // Sync-over-async, not great, not terrible. All we can do for the sync path. + _writeSemaphore.Release(); + } + } + + internal async ValueTask WriteAsync(ReadOnlyMemory payload) + { + if (!_frameHandler.IsClosed) + { + if (!_writeSemaphore.Wait(0)) + { + await _writeSemaphore.WaitAsync().ConfigureAwait(false); + } + + WriteSpan(payload.Span); + await _frameHandler.FrameWriter.FlushAsync().ConfigureAwait(false); + _writeSemaphore.Release(); + } + } + + private void WriteSpan(ReadOnlySpan payload) + { + Span span = _frameHandler.FrameWriter.GetSpan(payload.Length); + payload.CopyTo(span); + _frameHandler.FrameWriter.Advance(payload.Length); + } + + internal void Write(ref TMethod method, ushort channelNumber) where TMethod : struct, IOutgoingAmqpMethod + { + if (!_frameHandler.IsClosed) + { + _writeSemaphore.Wait(); + Client.Impl.Framing.SerializeToFrames(ref method, _frameHandler.FrameWriter, channelNumber); + _frameHandler.FrameWriter.FlushAsync().AsTask().Wait(); // Sync-over-async, not great, not terrible. All we can do for the sync path. + _writeSemaphore.Release(); + } + } + + internal async ValueTask WriteAsync(TMethod method, ushort channelNumber) where TMethod : struct, IOutgoingAmqpMethod + { + if (!_frameHandler.IsClosed) + { + if (!_writeSemaphore.Wait(0)) + { + await _writeSemaphore.WaitAsync().ConfigureAwait(false); + } + + Client.Impl.Framing.SerializeToFrames(ref method, _frameHandler.FrameWriter, channelNumber); + await _frameHandler.FrameWriter.FlushAsync().ConfigureAwait(false); + _writeSemaphore.Release(); + } + } + + internal void Write(ref TMethod method, ref THeader header, ReadOnlyMemory body, ushort channelNumber, int maxBodyPayloadBytes) + where TMethod : struct, IOutgoingAmqpMethod + where THeader : IAmqpHeader { - _frameHandler.Write(memory); + if (!_frameHandler.IsClosed) + { + _writeSemaphore.Wait(); + Client.Impl.Framing.SerializeToFrames(ref method, ref header, body, _frameHandler.FrameWriter, channelNumber, maxBodyPayloadBytes); + _frameHandler.FrameWriter.FlushAsync().AsTask().Wait(); // Sync-over-async, not great, not terrible. All we can do for the sync path. + _writeSemaphore.Release(); + } + } + + internal async ValueTask WriteAsync(TMethod method, THeader header, ReadOnlyMemory body, ushort channelNumber, int maxBodyPayloadBytes) + where TMethod : struct, IOutgoingAmqpMethod + where THeader : IAmqpHeader + { + if (!_frameHandler.IsClosed) + { + if (!_writeSemaphore.Wait(0)) + { + await _writeSemaphore.WaitAsync().ConfigureAwait(false); + } + + Client.Impl.Framing.SerializeToFrames(ref method, ref header, body, _frameHandler.FrameWriter, channelNumber, maxBodyPayloadBytes); + await _frameHandler.FrameWriter.FlushAsync().ConfigureAwait(false); + _writeSemaphore.Release(); + } } public void Dispose() diff --git a/projects/RabbitMQ.Client/client/impl/Frame.cs b/projects/RabbitMQ.Client/client/impl/Frame.cs index b7fe64a680..98661419da 100644 --- a/projects/RabbitMQ.Client/client/impl/Frame.cs +++ b/projects/RabbitMQ.Client/client/impl/Frame.cs @@ -31,10 +31,7 @@ using System; using System.Buffers; -using System.IO; -using System.Net.Sockets; using System.Runtime.CompilerServices; -using System.Runtime.ExceptionServices; using RabbitMQ.Client.Exceptions; using RabbitMQ.Client.Framing.Impl; @@ -131,41 +128,30 @@ internal static class Heartbeat /// /// Compiler trick to directly refer to static data in the assembly, see here: https://github.com/dotnet/roslyn/pull/24621 /// - private static ReadOnlySpan Payload => new byte[] + internal static ReadOnlySpan Payload => new byte[] { Constants.FrameHeartbeat, 0, 0, // channel 0, 0, 0, 0, // payload length Constants.FrameEnd }; - - public static Memory GetHeartbeatFrame() - { - // Is returned by SocketFrameHandler.WriteLoop - byte[] buffer = ArrayPool.Shared.Rent(FrameSize); - Payload.CopyTo(buffer); - return new Memory(buffer, 0, FrameSize); - } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static ReadOnlyMemory SerializeToFrames(ref T method, ushort channelNumber) - where T : struct, IOutgoingAmqpMethod + public static void SerializeToFrames(ref TMethod method, TBufferWriter pipeWriter, ushort channelNumber) where TMethod : struct, IOutgoingAmqpMethod where TBufferWriter : IBufferWriter { int size = Method.FrameSize + method.GetRequiredBufferSize(); - - // Will be returned by SocketFrameWriter.WriteLoop - var array = ArrayPool.Shared.Rent(size); - int offset = Method.WriteTo(array, channelNumber, ref method); - - System.Diagnostics.Debug.Assert(offset == size, $"Serialized to wrong size, expect {size}, offset {offset}"); - return new ReadOnlyMemory(array, 0, size); + Span outputBuffer = pipeWriter.GetSpan(size); + int offset = Method.WriteTo(outputBuffer, channelNumber, ref method); + System.Diagnostics.Debug.Assert(offset == size, "Serialized to wrong size", "Serialized to wrong size, expect {0:N0}, offset {1:N0}", size, offset); + pipeWriter.Advance(size); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static ReadOnlyMemory SerializeToFrames(ref TMethod method, ref THeader header, ReadOnlyMemory body, ushort channelNumber, int maxBodyPayloadBytes) + public static void SerializeToFrames(ref TMethod method, ref THeader header, ReadOnlyMemory body, TBufferWriter pipeWriter, ushort channelNumber, int maxBodyPayloadBytes) where TMethod : struct, IOutgoingAmqpMethod where THeader : IAmqpHeader + where TBufferWriter : IBufferWriter { int remainingBodyBytes = body.Length; int size = Method.FrameSize + Header.FrameSize + @@ -173,20 +159,20 @@ public static ReadOnlyMemory SerializeToFrames(ref TMeth BodySegment.FrameSize * GetBodyFrameCount(maxBodyPayloadBytes, remainingBodyBytes) + remainingBodyBytes; // Will be returned by SocketFrameWriter.WriteLoop - var array = ArrayPool.Shared.Rent(size); + Span outputBuffer = pipeWriter.GetSpan(size); - int offset = Method.WriteTo(array, channelNumber, ref method); - offset += Header.WriteTo(array.AsSpan(offset), channelNumber, ref header, remainingBodyBytes); + int offset = Method.WriteTo(outputBuffer, channelNumber, ref method); + offset += Header.WriteTo(outputBuffer.Slice(offset), channelNumber, ref header, remainingBodyBytes); var bodySpan = body.Span; while (remainingBodyBytes > 0) { int frameSize = remainingBodyBytes > maxBodyPayloadBytes ? maxBodyPayloadBytes : remainingBodyBytes; - offset += BodySegment.WriteTo(array.AsSpan(offset), channelNumber, bodySpan.Slice(bodySpan.Length - remainingBodyBytes, frameSize)); + offset += BodySegment.WriteTo(outputBuffer.Slice(offset), channelNumber, bodySpan.Slice(bodySpan.Length - remainingBodyBytes, frameSize)); remainingBodyBytes -= frameSize; } - System.Diagnostics.Debug.Assert(offset == size, $"Serialized to wrong size, expect {size}, offset {offset}"); - return new ReadOnlyMemory(array, 0, size); + System.Diagnostics.Debug.Assert(offset == size, "Serialized to wrong size", "Serialized to wrong size, expect {0:N0}, offset {1:N0}", size, offset); + pipeWriter.Advance(size); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -201,7 +187,7 @@ private static int GetBodyFrameCount(int maxPayloadBytes, int length) } } - internal readonly ref struct InboundFrame + internal readonly struct InboundFrame { public readonly FrameType Type; public readonly int Channel; @@ -216,113 +202,48 @@ private InboundFrame(FrameType type, int channel, ReadOnlyMemory payload, _rentedArray = rentedArray; } - private static void ProcessProtocolHeader(Stream reader, ReadOnlySpan frameHeader) + internal static void ProcessProtocolHeader(ReadOnlySequence buffer) { - try - { - if (frameHeader[0] != 'M' || frameHeader[1] != 'Q' || frameHeader[2] != 'P') - { - throw new MalformedFrameException("Invalid AMQP protocol header from server"); - } - - int serverMinor = reader.ReadByte(); - if (serverMinor == -1) - { - throw new EndOfStreamException(); - } - - throw new PacketNotRecognizedException(frameHeader[3], frameHeader[4], frameHeader[5], serverMinor); - } - catch (EndOfStreamException) + Span protocolSpan = stackalloc byte[7]; + buffer.Slice(1, 7).CopyTo(protocolSpan); + if (protocolSpan[0] != 'M' || protocolSpan[1] != 'Q' || protocolSpan[2] != 'P') { - // Ideally we'd wrap the EndOfStreamException in the - // MalformedFrameException, but unfortunately the - // design of MalformedFrameException's superclass, - // ProtocolViolationException, doesn't permit - // this. Fortunately, the call stack in the - // EndOfStreamException is largely irrelevant at this - // point, so can safely be ignored. throw new MalformedFrameException("Invalid AMQP protocol header from server"); } + + throw new PacketNotRecognizedException(protocolSpan[3], protocolSpan[4], protocolSpan[5], protocolSpan[6]); } - internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer) + public static bool TryParseInboundFrame(ref ReadOnlySequence buffer, out InboundFrame frame) { - try - { - ReadFromStream(reader, frameHeaderBuffer, frameHeaderBuffer.Length); - } - catch (IOException ioe) + int payloadSize = NetworkOrderDeserializer.ReadInt32(buffer.Slice(3, 4)); // FIXME - throw exn on unreasonable value + int frameSize = payloadSize + 8; + if (buffer.Length < frameSize) { - // If it's a WSAETIMEDOUT SocketException, unwrap it. - // This might happen when the limit of half-open connections is - // reached. - if (ioe?.InnerException is SocketException exception && exception.SocketErrorCode == SocketError.TimedOut) - { - ExceptionDispatchInfo.Capture(exception).Throw(); - } - else - { - throw; - } + frame = default; + return false; } - byte firstByte = frameHeaderBuffer[0]; - if (firstByte == 'A') - { - // Probably an AMQP protocol header, otherwise meaningless - ProcessProtocolHeader(reader, frameHeaderBuffer.AsSpan(1, 6)); - } - - FrameType type = (FrameType)firstByte; - var frameHeaderSpan = new ReadOnlySpan(frameHeaderBuffer, 1, 6); - int channel = NetworkOrderDeserializer.ReadUInt16(frameHeaderSpan); - int payloadSize = NetworkOrderDeserializer.ReadInt32(frameHeaderSpan.Slice(2, 4)); // FIXME - throw exn on unreasonable value - - const int EndMarkerLength = 1; - // Is returned by InboundFrame.ReturnPayload in Connection.MainLoopIteration - int readSize = payloadSize + EndMarkerLength; - byte[] payloadBytes = ArrayPool.Shared.Rent(readSize); - try - { - ReadFromStream(reader, payloadBytes, readSize); - } - catch (Exception) - { - // Early EOF. - ArrayPool.Shared.Return(payloadBytes); - throw new MalformedFrameException($"Short frame - expected to read {readSize} bytes"); - } - - if (payloadBytes[payloadSize] != Constants.FrameEnd) + byte[] payloadBytes = ArrayPool.Shared.Rent(frameSize); + buffer.Slice(0, frameSize).CopyTo(payloadBytes); + if (payloadBytes[frameSize - 1] != Constants.FrameEnd) { ArrayPool.Shared.Return(payloadBytes); - throw new MalformedFrameException($"Bad frame end marker: {payloadBytes[payloadSize]}"); + frame = default; + return ThrowMalformedFrameException(payloadBytes[frameSize - 1]); } - RabbitMqClientEventSource.Log.DataReceived(payloadSize + Framing.BaseFrameSize); - return new InboundFrame(type, channel, new Memory(payloadBytes, 0, payloadSize), payloadBytes); + buffer = buffer.Slice(frameSize); + RabbitMqClientEventSource.Log.DataReceived(frameSize); + FrameType frameType = (FrameType)payloadBytes[0]; + ushort channel = NetworkOrderDeserializer.ReadUInt16(ref Unsafe.AsRef(payloadBytes[1])); + frame = new InboundFrame(frameType, channel, payloadBytes.AsMemory(7, payloadSize), payloadBytes); + return true; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void ReadFromStream(Stream reader, byte[] buffer, int toRead) + private static bool ThrowMalformedFrameException(byte value) { - int bytesRead = 0; - do - { - int read = reader.Read(buffer, bytesRead, toRead - bytesRead); - if (read == 0) - { - ThrowEndOfStream(); - } - - bytesRead += read; - } while (bytesRead != toRead); - - static void ThrowEndOfStream() - { - throw new EndOfStreamException("Reached the end of the stream. Possible authentication failure."); - } + throw new MalformedFrameException($"Bad frame end marker: {value}"); } public byte[] TakeoverPayload() diff --git a/projects/RabbitMQ.Client/client/impl/IFrameHandler.cs b/projects/RabbitMQ.Client/client/impl/IFrameHandler.cs index bdb692bd1e..22f27e51fd 100644 --- a/projects/RabbitMQ.Client/client/impl/IFrameHandler.cs +++ b/projects/RabbitMQ.Client/client/impl/IFrameHandler.cs @@ -30,6 +30,7 @@ //--------------------------------------------------------------------------- using System; +using System.IO.Pipelines; using System.Net; namespace RabbitMQ.Client.Impl @@ -54,13 +55,9 @@ internal interface IFrameHandler void Close(); - ///Read a frame from the underlying - ///transport. Returns null if the read operation timed out - ///(see Timeout property). - InboundFrame ReadFrame(); + PipeReader FrameReader { get; } - void SendHeader(); - - void Write(ReadOnlyMemory memory); + PipeWriter FrameWriter { get; } + bool IsClosed { get; } } } diff --git a/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs b/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs index a8d6e16c2e..36ba5d3c87 100644 --- a/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs @@ -59,6 +59,7 @@ protected ProtocolBase() public abstract int MajorVersion { get; } public abstract int MinorVersion { get; } public abstract int Revision { get; } + public abstract ReadOnlySpan Header { get; } public AmqpVersion Version { diff --git a/projects/RabbitMQ.Client/client/impl/RecoveryAwareModel.cs b/projects/RabbitMQ.Client/client/impl/RecoveryAwareModel.cs index 8a0d0105f3..1f84bbab87 100644 --- a/projects/RabbitMQ.Client/client/impl/RecoveryAwareModel.cs +++ b/projects/RabbitMQ.Client/client/impl/RecoveryAwareModel.cs @@ -29,7 +29,6 @@ // Copyright (c) 2007-2020 VMware, Inc. All rights reserved. //--------------------------------------------------------------------------- -using RabbitMQ.Client.client.impl; using RabbitMQ.Client.Framing.Impl; namespace RabbitMQ.Client.Impl diff --git a/projects/RabbitMQ.Client/client/impl/SessionBase.cs b/projects/RabbitMQ.Client/client/impl/SessionBase.cs index c6991f5c70..3ae5b34911 100644 --- a/projects/RabbitMQ.Client/client/impl/SessionBase.cs +++ b/projects/RabbitMQ.Client/client/impl/SessionBase.cs @@ -136,7 +136,7 @@ public virtual void Transmit(ref T cmd) where T : struct, IOutgoingAmqpMethod ThrowAlreadyClosedException(); } - Connection.Write(Framing.SerializeToFrames(ref cmd, ChannelNumber)); + Connection.Write(ref cmd, ChannelNumber); } public void Transmit(ref TMethod cmd, ref THeader header, ReadOnlyMemory body) @@ -148,7 +148,7 @@ public void Transmit(ref TMethod cmd, ref THeader header, Read ThrowAlreadyClosedException(); } - Connection.Write(Framing.SerializeToFrames(ref cmd, ref header, body, ChannelNumber, Connection.MaxPayloadSize)); + Connection.Write(ref cmd, ref header, body, ChannelNumber, Connection.MaxPayloadSize); } private void ThrowAlreadyClosedException() diff --git a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs index e28ac4742d..66062074a9 100644 --- a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs +++ b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs @@ -30,17 +30,15 @@ //--------------------------------------------------------------------------- using System; -using System.Buffers; -using System.IO; +using System.IO.Pipelines; using System.Net; using System.Net.Sockets; -using System.Runtime.InteropServices; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; +using Pipelines.Sockets.Unofficial; + using RabbitMQ.Client.Exceptions; -using RabbitMQ.Client.Logging; namespace RabbitMQ.Client.Impl { @@ -48,13 +46,10 @@ internal static class TaskExtensions { public static async Task TimeoutAfter(this Task task, TimeSpan timeout) { - if (task == await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false)) - { - await task.ConfigureAwait(false); - } - else + Task returnedTask = await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false); + if (task != returnedTask) { - Task supressErrorTask = task.ContinueWith((t, s) => t.Exception.Handle(e => true), null, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + Task supressErrorTask = returnedTask.ContinueWith((t, s) => t.Exception.Handle(e => true), null, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); throw new TimeoutException(); } } @@ -63,75 +58,78 @@ public static async Task TimeoutAfter(this Task task, TimeSpan timeout) internal sealed class SocketFrameHandler : IFrameHandler { private readonly ITcpClient _socket; - private readonly Stream _reader; - private readonly Stream _writer; - private readonly ChannelWriter> _channelWriter; - private readonly ChannelReader> _channelReader; - private readonly Task _writerTask; + + // Pipes + private readonly IDuplexPipe _pipe; + public PipeReader FrameReader => _pipe.Input; + public PipeWriter FrameWriter => _pipe.Output; + private readonly object _semaphore = new object(); - private readonly byte[] _frameHeaderBuffer; - private bool _closed; - public SocketFrameHandler(AmqpTcpEndpoint endpoint, - Func socketFactory, - TimeSpan connectionTimeout, TimeSpan readTimeout, TimeSpan writeTimeout) + public bool IsClosed { get; private set; } + + public SocketFrameHandler(AmqpTcpEndpoint endpoint, Func socketFactory, TimeSpan connectionTimeout, TimeSpan readTimeout, TimeSpan writeTimeout) { Endpoint = endpoint; - _frameHeaderBuffer = new byte[7]; - var channel = Channel.CreateUnbounded>( - new UnboundedChannelOptions - { - AllowSynchronousContinuations = false, - SingleReader = true, - SingleWriter = false - }); - - _channelReader = channel.Reader; - _channelWriter = channel.Writer; - // Resolve the hostname to know if it's even possible to even try IPv6 - IPAddress[] adds = Dns.GetHostAddresses(endpoint.HostName); - IPAddress ipv6 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetworkV6); - - if (ipv6 == default(IPAddress)) + // Let's check and see if we are connecting as an IP address first + if (IPAddress.TryParse(endpoint.HostName, out IPAddress address)) { - if (endpoint.AddressFamily == AddressFamily.InterNetworkV6) + // Connecting straight via IP so we can ignore whatever AddressFamily is set on the endpoint. We support IPv4 and IPv6. + _socket = address.AddressFamily switch { - throw new ConnectFailureException("Connection failed", new ArgumentException($"No IPv6 address could be resolved for {endpoint.HostName}")); - } + AddressFamily.InterNetwork or AddressFamily.InterNetworkV6 => ConnectUsingAddressFamily(new IPEndPoint(address, endpoint.Port), socketFactory, connectionTimeout, address.AddressFamily), + _ => throw new ConnectFailureException("Connection failed", new ArgumentException($"AddressFamily {address.AddressFamily} is not supported.")), + }; } - else if (ShouldTryIPv6(endpoint)) + else { - try - { - _socket = ConnectUsingIPv6(new IPEndPoint(ipv6, endpoint.Port), socketFactory, connectionTimeout); - } - catch (ConnectFailureException) + // We are connecting via. hostname so let's first resolve all the IP addresses for the hostname + IPAddress[] adds = Dns.GetHostAddresses(endpoint.HostName); + + // We want to connect via. IPv6 and our Socket supports IPv6 so let's try that first + if ((endpoint.AddressFamily == AddressFamily.InterNetworkV6 || endpoint.AddressFamily == AddressFamily.Unknown) && Socket.OSSupportsIPv6) { - // We resolved to a ipv6 address and tried it but it still didn't connect, try IPv4 - _socket = null; + // Let's then try to find the appropriate IP address for the hostname + IPAddress ipv6 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetworkV6); + if (ipv6 == null) + { + throw new ConnectFailureException("Connection failed", new ArgumentException($"No IPv6 address could be resolved for {endpoint.HostName}")); + } + + // Let's see if we can connect to the resolved IP address as IPv6 + try + { + _socket = ConnectUsingAddressFamily(new IPEndPoint(ipv6, endpoint.Port), socketFactory, connectionTimeout, AddressFamily.InterNetworkV6); + } + catch (ConnectFailureException) + { + // Didn't work, let's fall-back to IPv4 + _socket = null; + } } - } - if (_socket is null) - { - IPAddress ipv4 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetwork); - if (ipv4 == default(IPAddress)) + // No dice, let's fall-back to IPv4 then. + if (_socket is null) { - throw new ConnectFailureException("Connection failed", new ArgumentException($"No ip address could be resolved for {endpoint.HostName}")); + IPAddress ipv4 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetwork); + if (ipv4 == null) + { + throw new ConnectFailureException("Connection failed", new ArgumentException($"No IPv4 address could be resolved for {endpoint.HostName}")); + } + + _socket = ConnectUsingAddressFamily(new IPEndPoint(ipv4, endpoint.Port), socketFactory, connectionTimeout, AddressFamily.InterNetwork); } - _socket = ConnectUsingIPv4(new IPEndPoint(ipv4, endpoint.Port), socketFactory, connectionTimeout); } - Stream netstream = _socket.GetStream(); - netstream.ReadTimeout = (int)readTimeout.TotalMilliseconds; - netstream.WriteTimeout = (int)writeTimeout.TotalMilliseconds; + // We're done setting up our connection, let's configure timeouts and SSL if needed. + _socket.ReceiveTimeout = readTimeout; if (endpoint.Ssl.Enabled) { try { - netstream = SslHelper.TcpUpgrade(netstream, endpoint.Ssl); + _pipe = StreamConnection.GetDuplex(SslHelper.TcpUpgrade(_socket.GetStream(), endpoint.Ssl)); } catch (Exception) { @@ -139,34 +137,19 @@ public SocketFrameHandler(AmqpTcpEndpoint endpoint, throw; } } - - _reader = new BufferedStream(netstream, _socket.Client.ReceiveBufferSize); - _writer = new BufferedStream(netstream, _socket.Client.SendBufferSize); + else + { + _pipe = SocketConnection.Create(_socket.Client); + } WriteTimeout = writeTimeout; - _writerTask = Task.Run(WriteLoop); - } - public AmqpTcpEndpoint Endpoint { get; set; } - - public EndPoint LocalEndPoint - { - get { return _socket.Client.LocalEndPoint; } } - public int LocalPort - { - get { return ((IPEndPoint)LocalEndPoint).Port; } - } - - public EndPoint RemoteEndPoint - { - get { return _socket.Client.RemoteEndPoint; } - } - - public int RemotePort - { - get { return ((IPEndPoint)RemoteEndPoint).Port; } - } + public AmqpTcpEndpoint Endpoint { get; set; } + public EndPoint LocalEndPoint => _socket.Client.LocalEndPoint; + public int LocalPort => ((IPEndPoint)LocalEndPoint).Port; + public EndPoint RemoteEndPoint => _socket.Client.RemoteEndPoint; + public int RemotePort => ((IPEndPoint)RemoteEndPoint).Port; public TimeSpan ReadTimeout { @@ -188,17 +171,14 @@ public TimeSpan ReadTimeout public TimeSpan WriteTimeout { - set - { - _socket.Client.SendTimeout = (int)value.TotalMilliseconds; - } + set => _socket.Client.SendTimeout = (int)value.TotalMilliseconds; } public void Close() { lock (_semaphore) { - if (_closed || _socket == null) + if (IsClosed || _socket == null) { return; } @@ -206,8 +186,7 @@ public void Close() { try { - _channelWriter.Complete(); - _writerTask?.GetAwaiter().GetResult(); + FrameWriter.Complete(); } catch { @@ -224,97 +203,12 @@ public void Close() } finally { - _closed = true; + IsClosed = true; } } } } - public InboundFrame ReadFrame() - { - return InboundFrame.ReadFrom(_reader, _frameHeaderBuffer); - } - - public void SendHeader() - { -#if NETSTANDARD - var headerBytes = new byte[8]; -#else - Span headerBytes = stackalloc byte[8]; -#endif - - headerBytes[0] = (byte)'A'; - headerBytes[1] = (byte)'M'; - headerBytes[2] = (byte)'Q'; - headerBytes[3] = (byte)'P'; - - if (Endpoint.Protocol.Revision != 0) - { - headerBytes[4] = 0; - headerBytes[5] = (byte)Endpoint.Protocol.MajorVersion; - headerBytes[6] = (byte)Endpoint.Protocol.MinorVersion; - headerBytes[7] = (byte)Endpoint.Protocol.Revision; - } - else - { - headerBytes[4] = 1; - headerBytes[5] = 1; - headerBytes[6] = (byte)Endpoint.Protocol.MajorVersion; - headerBytes[7] = (byte)Endpoint.Protocol.MinorVersion; - } - -#if NETSTANDARD - _writer.Write(headerBytes, 0, 8); -#else - _writer.Write(headerBytes); -#endif - _writer.Flush(); - } - - public void Write(ReadOnlyMemory memory) - { - _channelWriter.TryWrite(memory); - } - - private async Task WriteLoop() - { - while (await _channelReader.WaitToReadAsync().ConfigureAwait(false)) - { - while (_channelReader.TryRead(out ReadOnlyMemory memory)) - { - MemoryMarshal.TryGetArray(memory, out ArraySegment segment); -#if NETSTANDARD - await _writer.WriteAsync(segment.Array, segment.Offset, segment.Count).ConfigureAwait(false); -#else - await _writer.WriteAsync(memory).ConfigureAwait(false); -#endif - RabbitMqClientEventSource.Log.CommandSent(segment.Count); - ArrayPool.Shared.Return(segment.Array); - } - - await _writer.FlushAsync().ConfigureAwait(false); - } - } - - private static bool ShouldTryIPv6(AmqpTcpEndpoint endpoint) - { - return Socket.OSSupportsIPv6 && endpoint.AddressFamily != AddressFamily.InterNetwork; - } - - private ITcpClient ConnectUsingIPv6(IPEndPoint endpoint, - Func socketFactory, - TimeSpan timeout) - { - return ConnectUsingAddressFamily(endpoint, socketFactory, timeout, AddressFamily.InterNetworkV6); - } - - private ITcpClient ConnectUsingIPv4(IPEndPoint endpoint, - Func socketFactory, - TimeSpan timeout) - { - return ConnectUsingAddressFamily(endpoint, socketFactory, timeout, AddressFamily.InterNetwork); - } - private ITcpClient ConnectUsingAddressFamily(IPEndPoint endpoint, Func socketFactory, TimeSpan timeout, AddressFamily family) diff --git a/projects/RabbitMQ.Client/client/impl/TcpClientAdapterHelper.cs b/projects/RabbitMQ.Client/client/impl/TcpClientAdapterHelper.cs index a9a9cf97e1..4d6f35df5a 100644 --- a/projects/RabbitMQ.Client/client/impl/TcpClientAdapterHelper.cs +++ b/projects/RabbitMQ.Client/client/impl/TcpClientAdapterHelper.cs @@ -1,20 +1,27 @@ -using System.Collections.Generic; -using System.Linq; -using System.Net; +using System.Net; using System.Net.Sockets; namespace RabbitMQ.Client.Impl { internal static class TcpClientAdapterHelper { - public static IPAddress GetMatchingHost(IReadOnlyCollection addresses, AddressFamily addressFamily) + public static IPAddress GetMatchingHost(IPAddress[] addresses, AddressFamily addressFamily) { - IPAddress ep = addresses.FirstOrDefault(a => a.AddressFamily == addressFamily); - if (ep is null && addresses.Count == 1 && addressFamily == AddressFamily.Unspecified) + if (addresses != null && addresses.Length == 1 && addressFamily == AddressFamily.Unspecified) { - return addresses.Single(); + return addresses[0]; } - return ep; + + for (int i = 0; i < addresses.Length; i++) + { + IPAddress address = addresses[i]; + if (address.AddressFamily == addressFamily) + { + return address; + } + } + + return null; } } } diff --git a/projects/RabbitMQ.Client/util/NetworkOrderDeserializer.cs b/projects/RabbitMQ.Client/util/NetworkOrderDeserializer.cs index a4b5c04aed..236d8751d4 100644 --- a/projects/RabbitMQ.Client/util/NetworkOrderDeserializer.cs +++ b/projects/RabbitMQ.Client/util/NetworkOrderDeserializer.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Buffers.Binary; using System.Runtime.CompilerServices; @@ -24,6 +25,19 @@ internal static int ReadInt32(ReadOnlySpan span) return BinaryPrimitives.ReadInt32BigEndian(span); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int ReadInt32(ReadOnlySequence buffer) + { + if (buffer.First.Length >= 4) + { + return ReadInt32(buffer.First.Span); + } + + Span span = stackalloc byte[4]; + buffer.Slice(0, 4).CopyTo(span); + return ReadInt32(span); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static long ReadInt64(ReadOnlySpan span) { @@ -43,6 +57,27 @@ internal static ushort ReadUInt16(ReadOnlySpan span) return BinaryPrimitives.ReadUInt16BigEndian(span); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static ushort ReadUInt16(ref byte source) + { + return BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref source)) : Unsafe.ReadUnaligned(ref source); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static ushort ReadUInt16(ReadOnlySequence buffer) + { + if (2 <= buffer.First.Length) + { + return ReadUInt16(buffer.First.Span); + } + else + { + Span span = stackalloc byte[2]; + buffer.Slice(0, 2).CopyTo(span); + return ReadUInt16(span); + } + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static uint ReadUInt32(ReadOnlySpan span) { diff --git a/projects/Unit/APIApproval.cs b/projects/Unit/APIApproval.cs index 14396c7703..8643fdfea4 100644 --- a/projects/Unit/APIApproval.cs +++ b/projects/Unit/APIApproval.cs @@ -29,7 +29,6 @@ // Copyright (c) 2007-2020 VMware, Inc. All rights reserved. //--------------------------------------------------------------------------- -using System.Reflection; using System.Threading.Tasks; using PublicApiGenerator; diff --git a/projects/Unit/Fixtures.cs b/projects/Unit/Fixtures.cs index cef17ded26..045875119c 100644 --- a/projects/Unit/Fixtures.cs +++ b/projects/Unit/Fixtures.cs @@ -398,14 +398,9 @@ internal void StartRabbitMQ() // Concurrency and Coordination // - internal void Wait(ManualResetEventSlim latch) + internal void Wait(ManualResetEventSlim latch, TimeSpan? timeSpan = null) { - Assert.True(latch.Wait(TimeSpan.FromSeconds(10)), "waiting on a latch timed out"); - } - - internal void Wait(ManualResetEventSlim latch, TimeSpan timeSpan) - { - Assert.True(latch.Wait(timeSpan), "waiting on a latch timed out"); + Assert.True(latch.Wait(timeSpan ?? TimeSpan.FromSeconds(15)), "waiting on a latch timed out"); } // diff --git a/projects/Unit/RabbitMQCtl.cs b/projects/Unit/RabbitMQCtl.cs index 3fcdf014cf..29242a6e31 100644 --- a/projects/Unit/RabbitMQCtl.cs +++ b/projects/Unit/RabbitMQCtl.cs @@ -44,9 +44,8 @@ namespace RabbitMQ.Client.Unit public static class RabbitMQCtl { private static readonly char[] newLine = new char[] { '\n' }; - private static readonly Func s_invokeRabbitMqCtl = GetRabbitMqCtlInvokeAction(); - private static Func GetRabbitMqCtlInvokeAction() + private static Process GetRabbitMqCtlInvokeAction(string args) { string precomputedArguments; string? envVariable = Environment.GetEnvironmentVariable("RABBITMQ_RABBITMQCTL_PATH"); @@ -58,11 +57,11 @@ private static Func GetRabbitMqCtlInvokeAction() { // Call docker precomputedArguments = $"exec {envVariable.Substring(DockerPrefix.Length)} rabbitmqctl "; - return args => CreateProcess("docker", precomputedArguments + args); + return CreateProcess("docker", precomputedArguments + args); } // call the path from the env var - return args => CreateProcess(envVariable, args); + return CreateProcess(envVariable, args); } // Try default @@ -84,11 +83,11 @@ private static Func GetRabbitMqCtlInvokeAction() if (IsRunningOnMonoOrDotNetCore()) { - return args => CreateProcess(path, args); + return CreateProcess(path, args); } precomputedArguments = $"/c \"\"{path}\" "; - return args => CreateProcess("cmd.exe", precomputedArguments + args); + return CreateProcess("cmd.exe", precomputedArguments + args); } // @@ -98,7 +97,7 @@ private static string ExecRabbitMQCtl(string args) { try { - using var process = s_invokeRabbitMqCtl(args); + using var process = GetRabbitMqCtlInvokeAction(args); process.Start(); process.WaitForExit(); string stderr = process.StandardError.ReadToEnd(); @@ -137,7 +136,7 @@ private static Process CreateProcess(string cmd, string arguments, string? workD private static void ReportExecFailure(string cmd, string args, string msg) { - Console.WriteLine($"Failure while running {cmd} {args}:\n{msg}"); + Xunit.Assert.True(false, $"Failure while running {cmd} {args}:\n{msg}"); } private static bool IsRunningOnMonoOrDotNetCore() diff --git a/projects/Unit/TestAsyncConsumer.cs b/projects/Unit/TestAsyncConsumer.cs index 43ccd02fa5..08032109a3 100644 --- a/projects/Unit/TestAsyncConsumer.cs +++ b/projects/Unit/TestAsyncConsumer.cs @@ -30,6 +30,7 @@ //--------------------------------------------------------------------------- using System; +using System.Buffers; using System.Security.Cryptography; using System.Text; using System.Threading; @@ -44,6 +45,7 @@ namespace RabbitMQ.Client.Unit public class TestAsyncConsumer { private readonly ITestOutputHelper _output; + private static readonly RandomNumberGenerator s_randomNumberGenerator = RandomNumberGenerator.Create(); public TestAsyncConsumer(ITestOutputHelper output) { @@ -90,11 +92,11 @@ public async Task TestBasicRoundtripConcurrent() using (IModel m = c.CreateModel()) { QueueDeclareOk q = m.QueueDeclare(); - string publish1 = get_unique_string(1024); + string publish1 = TestAsyncConsumer.get_unique_string(1024); byte[] body = Encoding.UTF8.GetBytes(publish1); m.BasicPublish("", q.QueueName, body); - string publish2 = get_unique_string(1024); + string publish2 = TestAsyncConsumer.get_unique_string(1024); body = Encoding.UTF8.GetBytes(publish2); m.BasicPublish("", q.QueueName, body); @@ -141,11 +143,8 @@ public async Task TestBasicRoundtripConcurrentManyMessages() { const int publish_total = 4096; string queueName = $"{nameof(TestBasicRoundtripConcurrentManyMessages)}-{Guid.NewGuid()}"; - - string publish1 = get_unique_string(32768); - byte[] body1 = Encoding.ASCII.GetBytes(publish1); - string publish2 = get_unique_string(32768); - byte[] body2 = Encoding.ASCII.GetBytes(publish2); + byte[] body1 = Encoding.ASCII.GetBytes(get_unique_string(4096)); + byte[] body2 = Encoding.ASCII.GetBytes(get_unique_string(4096)); var cf = new ConnectionFactory { DispatchConsumersAsync = true, ConsumerDispatchConcurrency = 2 }; @@ -174,7 +173,7 @@ public async Task TestBasicRoundtripConcurrentManyMessages() } }); - Task consumeTask = Task.Run(() => + Task consumeTask = Task.Run(async () => { var publish1SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -197,8 +196,7 @@ public async Task TestBasicRoundtripConcurrentManyMessages() consumer.Received += async (o, a) => { - string decoded = Encoding.ASCII.GetString(a.Body.ToArray()); - if (decoded == publish1) + if (a.Body.Span.SequenceEqual(body1)) { if (Interlocked.Increment(ref publish1_count) >= publish_total) { @@ -206,7 +204,7 @@ public async Task TestBasicRoundtripConcurrentManyMessages() await publish2SyncSource.Task; } } - else if (decoded == publish2) + else if (a.Body.Span.SequenceEqual(body2)) { if (Interlocked.Increment(ref publish2_count) >= publish_total) { @@ -219,7 +217,7 @@ public async Task TestBasicRoundtripConcurrentManyMessages() m.BasicConsume(queueName, true, consumer); // ensure we get a delivery - Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task); + await Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task); Assert.True(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); Assert.True(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); @@ -239,14 +237,14 @@ public void TestBasicRoundtripNoWait() using (IModel m = c.CreateModel()) { QueueDeclareOk q = m.QueueDeclare(); - byte[] body = System.Text.Encoding.UTF8.GetBytes("async-hi"); + byte[] body = Encoding.UTF8.GetBytes("async-hi"); m.BasicPublish("", q.QueueName, body); var consumer = new AsyncEventingBasicConsumer(m); var are = new AutoResetEvent(false); - consumer.Received += async (o, a) => + consumer.Received += (o, a) => { are.Set(); - await Task.Yield(); + return Task.CompletedTask; }; string tag = m.BasicConsume(q.QueueName, true, consumer); // ensure we get a delivery @@ -337,16 +335,15 @@ public void NonAsyncConsumerShouldThrowInvalidOperationException() } } - private string get_unique_string(int string_length) + private static string get_unique_string(int string_length) { - using (var rng = RandomNumberGenerator.Create()) - { - var bit_count = (string_length * 6); - var byte_count = ((bit_count + 7) / 8); // rounded up - var bytes = new byte[byte_count]; - rng.GetBytes(bytes); - return Convert.ToBase64String(bytes); - } + var bit_count = string_length * 6; + var byte_count = (bit_count + 7) / 8; // rounded up + var bytes = ArrayPool.Shared.Rent(byte_count); + s_randomNumberGenerator.GetBytes(bytes); + var result = Convert.ToBase64String(bytes); + ArrayPool.Shared.Return(bytes); + return result; } } } diff --git a/projects/Unit/TestBasicPublish.cs b/projects/Unit/TestBasicPublish.cs index e13103db3a..4a48a7a042 100644 --- a/projects/Unit/TestBasicPublish.cs +++ b/projects/Unit/TestBasicPublish.cs @@ -3,7 +3,7 @@ using System.Threading.Tasks; using RabbitMQ.Client.Events; -using RabbitMQ.Client.Framing; + using Xunit; namespace RabbitMQ.Client.Unit diff --git a/projects/Unit/TestFrameFormatting.cs b/projects/Unit/TestFrameFormatting.cs index 360cf084f0..3e6f8a23f8 100644 --- a/projects/Unit/TestFrameFormatting.cs +++ b/projects/Unit/TestFrameFormatting.cs @@ -30,8 +30,6 @@ //--------------------------------------------------------------------------- using System; -using System.Buffers; -using System.Runtime.InteropServices; using RabbitMQ.Client.Framing.Impl; @@ -44,28 +42,17 @@ public class TestFrameFormatting : WireFormattingFixture [Fact] public void HeartbeatFrame() { - Memory memory = Impl.Framing.Heartbeat.GetHeartbeatFrame(); - Span frameSpan = memory.Span; + ReadOnlySpan frameSpan = Impl.Framing.Heartbeat.Payload; - try - { - Assert.Equal(8, frameSpan.Length); - Assert.Equal(Constants.FrameHeartbeat, frameSpan[0]); - Assert.Equal(0, frameSpan[1]); // channel - Assert.Equal(0, frameSpan[2]); // channel - Assert.Equal(0, frameSpan[3]); // payload size - Assert.Equal(0, frameSpan[4]); // payload size - Assert.Equal(0, frameSpan[5]); // payload size - Assert.Equal(0, frameSpan[6]); // payload size - Assert.Equal(Constants.FrameEnd, frameSpan[7]); - } - finally - { - if (MemoryMarshal.TryGetArray(memory, out ArraySegment segment)) - { - ArrayPool.Shared.Return(segment.Array); - } - } + Assert.Equal(8, frameSpan.Length); + Assert.Equal(Constants.FrameHeartbeat, frameSpan[0]); + Assert.Equal(0, frameSpan[1]); // channel + Assert.Equal(0, frameSpan[2]); // channel + Assert.Equal(0, frameSpan[3]); // payload size + Assert.Equal(0, frameSpan[4]); // payload size + Assert.Equal(0, frameSpan[5]); // payload size + Assert.Equal(0, frameSpan[6]); // payload size + Assert.Equal(Constants.FrameEnd, frameSpan[7]); } [Fact] diff --git a/projects/Unit/TestPublisherConfirms.cs b/projects/Unit/TestPublisherConfirms.cs index 3e2b5b013c..a44ebd3572 100644 --- a/projects/Unit/TestPublisherConfirms.cs +++ b/projects/Unit/TestPublisherConfirms.cs @@ -52,46 +52,42 @@ public TestPublisherConfirms(ITestOutputHelper output) : base(output) var rnd = new Random(); _body = new byte[4096]; rnd.NextBytes(_body); - } [Fact] - public void TestWaitForConfirmsWithoutTimeout() + public async Task TestWaitForConfirmsWithoutTimeout() { - TestWaitForConfirms(200, (ch) => + await TestWaitForConfirms(200, async (ch) => { - Assert.True(ch.WaitForConfirmsAsync().GetAwaiter().GetResult()); + Assert.True(await ch.WaitForConfirmsAsync()); }); } [Fact] - public void TestWaitForConfirmsWithTimeout() + public async Task TestWaitForConfirmsWithTimeout() { - TestWaitForConfirms(200, (ch) => + await TestWaitForConfirms(200, async (ch) => { using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(4))) { - Assert.True(ch.WaitForConfirmsAsync(cts.Token).GetAwaiter().GetResult()); + Assert.True(await ch.WaitForConfirmsAsync(cts.Token)); } }); } [Fact] - public void TestWaitForConfirmsWithTimeout_AllMessagesAcked_WaitingHasTimedout_ReturnTrue() + public async Task TestWaitForConfirmsWithTimeout_AllMessagesAcked_WaitingHasTimedout_ReturnTrue() { - TestWaitForConfirms(10000, (ch) => + await TestWaitForConfirms(1000, async (ch) => { - using (var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(1))) - { - Assert.Throws(() => ch.WaitForConfirmsAsync(cts.Token).GetAwaiter().GetResult()); - } + await Assert.ThrowsAsync(async () => await ch.WaitForConfirmsAsync(new CancellationToken(true))); }); } [Fact] - public void TestWaitForConfirmsWithTimeout_MessageNacked_WaitingHasTimedout_ReturnFalse() + public async Task TestWaitForConfirmsWithTimeout_MessageNacked_WaitingHasTimedout_ReturnFalse() { - TestWaitForConfirms(2000, (ch) => + await TestWaitForConfirms(2000, async (ch) => { IModel actualModel = ((AutorecoveringModel)ch).InnerChannel; actualModel @@ -101,7 +97,7 @@ public void TestWaitForConfirmsWithTimeout_MessageNacked_WaitingHasTimedout_Retu using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(4))) { - Assert.False(ch.WaitForConfirmsAsync(cts.Token).GetAwaiter().GetResult()); + Assert.False(await ch.WaitForConfirmsAsync(cts.Token)); } }); } @@ -143,7 +139,7 @@ public async Task TestWaitForConfirmsWithEvents() } } - protected void TestWaitForConfirms(int numberOfMessagesToPublish, Action fn) + protected async Task TestWaitForConfirms(int numberOfMessagesToPublish, Func fn) { using (IModel ch = _conn.CreateModel()) { @@ -157,7 +153,7 @@ protected void TestWaitForConfirms(int numberOfMessagesToPublish, Action try { - fn(ch); + await fn(ch).ConfigureAwait(false); } finally { From 83eda9fb9c2bea5c97252b362ef1378a6805357c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stef=C3=A1n=20J=C3=B6kull=20Sigur=C3=B0arson?= Date: Mon, 25 Apr 2022 15:24:25 +0000 Subject: [PATCH 2/4] Cleaning up tests. --- projects/Unit/Fixtures.cs | 112 ++++++--------- projects/Unit/Helper/DebugUtil.cs | 35 +++-- projects/Unit/RabbitMQCtl.cs | 67 +++++---- projects/Unit/TestAsyncConsumer.cs | 4 +- projects/Unit/TestBasicGet.cs | 17 ++- projects/Unit/TestConnectionBlocked.cs | 66 ++++----- projects/Unit/TestConnectionRecovery.cs | 133 +++++++----------- projects/Unit/TestConsumerCancelNotify.cs | 33 ++--- projects/Unit/TestConsumerExceptions.cs | 11 +- .../Unit/TestConsumerOperationDispatch.cs | 4 +- projects/Unit/TestContentHeaderCodec.cs | 20 ++- projects/Unit/TestEventingConsumer.cs | 13 +- projects/Unit/TestFieldTableFormatting.cs | 5 + .../Unit/TestFieldTableFormattingGeneric.cs | 5 + projects/Unit/TestFrameFormatting.cs | 5 + projects/Unit/TestHeartbeats.cs | 6 +- projects/Unit/TestInvalidAck.cs | 6 +- projects/Unit/TestMainLoop.cs | 7 +- projects/Unit/TestMethodArgumentCodec.cs | 20 ++- .../Unit/TestNetworkByteOrderSerialization.cs | 20 ++- projects/Unit/TestPublishSharedModel.cs | 9 +- projects/Unit/TestUpdateSecret.cs | 2 +- projects/Unit/WireFormattingFixture.cs | 20 ++- 23 files changed, 296 insertions(+), 324 deletions(-) diff --git a/projects/Unit/Fixtures.cs b/projects/Unit/Fixtures.cs index 045875119c..f7b3ceaffa 100644 --- a/projects/Unit/Fixtures.cs +++ b/projects/Unit/Fixtures.cs @@ -34,8 +34,11 @@ using System; using System.Collections.Generic; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; +using System.Threading.Tasks; + using RabbitMQ.Client.Framing.Impl; using Xunit; using Xunit.Abstractions; @@ -97,37 +100,28 @@ protected virtual void ReleaseResources() // Connections // - internal AutorecoveringConnection CreateAutorecoveringConnection() - { - return CreateAutorecoveringConnection(RECOVERY_INTERVAL); - } - internal AutorecoveringConnection CreateAutorecoveringConnection(IList hostnames) - { - return CreateAutorecoveringConnection(RECOVERY_INTERVAL, hostnames); - } - - internal AutorecoveringConnection CreateAutorecoveringConnection(TimeSpan interval) { var cf = new ConnectionFactory { AutomaticRecoveryEnabled = true, - NetworkRecoveryInterval = interval + // tests that use this helper will likely list unreachable hosts, + // make sure we time out quickly on those + RequestedConnectionTimeout = TimeSpan.FromSeconds(1), + NetworkRecoveryInterval = RECOVERY_INTERVAL }; - return (AutorecoveringConnection)cf.CreateConnection($"{_testDisplayName}:{Guid.NewGuid()}"); + return (AutorecoveringConnection)cf.CreateConnection(hostnames, $"{_testDisplayName}:{Guid.NewGuid()}"); } - internal AutorecoveringConnection CreateAutorecoveringConnection(TimeSpan interval, IList hostnames) + internal AutorecoveringConnection CreateAutorecoveringConnection(TimeSpan? interval = null) { + interval ??= RECOVERY_INTERVAL; var cf = new ConnectionFactory { AutomaticRecoveryEnabled = true, - // tests that use this helper will likely list unreachable hosts, - // make sure we time out quickly on those - RequestedConnectionTimeout = TimeSpan.FromSeconds(1), - NetworkRecoveryInterval = interval + NetworkRecoveryInterval = interval.Value }; - return (AutorecoveringConnection)cf.CreateConnection(hostnames, $"{_testDisplayName}:{Guid.NewGuid()}"); + return (AutorecoveringConnection)cf.CreateConnection($"{_testDisplayName}:{Guid.NewGuid()}"); } internal AutorecoveringConnection CreateAutorecoveringConnection(IList endpoints) @@ -210,13 +204,13 @@ internal byte[] RandomMessageBody() return _encoding.GetBytes(Guid.NewGuid().ToString()); } - internal string DeclareNonDurableExchange(IModel m, string x) + internal string DeclareNonDurableExchange(IModel m, [CallerMemberName]string x = null) { m.ExchangeDeclare(x, "fanout", false); return x; } - internal string DeclareNonDurableExchangeNoWait(IModel m, string x) + internal string DeclareNonDurableExchangeNoWait(IModel m, [CallerMemberName]string x = null) { m.ExchangeDeclareNoWait(x, "fanout", false, false, null); return x; @@ -226,40 +220,35 @@ internal string DeclareNonDurableExchangeNoWait(IModel m, string x) // Queues // - internal string GenerateQueueName() - { - return $"queue{Guid.NewGuid()}"; - } - - internal void WithTemporaryNonExclusiveQueue(Action action) + internal string GenerateQueueName([CallerMemberName] string callerName = null) { - WithTemporaryNonExclusiveQueue(_model, action); + return $"queue{Guid.NewGuid()}{callerName}"; } - internal void WithTemporaryNonExclusiveQueue(IModel model, Action action) + internal Task WithTemporaryNonExclusiveQueueAsync(Func action) { - WithTemporaryNonExclusiveQueue(model, action, GenerateQueueName()); + return WithTemporaryNonExclusiveQueueAsync(_model, action); } - internal void WithTemporaryNonExclusiveQueue(IModel model, Action action, string queue) + internal async Task WithTemporaryNonExclusiveQueueAsync(IModel model, Func action, [CallerMemberName] string queueName = null) { try { - model.QueueDeclare(queue, false, false, false, null); - action(model, queue); + model.QueueDeclare(queueName, false, false, false, null); + await action(model, queueName); } finally { - WithTemporaryModel(tm => tm.QueueDelete(queue)); + WithTemporaryModel(tm => tm.QueueDelete(queueName)); } } - internal void WithTemporaryQueueNoWait(IModel model, Action action, string queue) + internal async Task WithTemporaryQueueNoWaitAsync(IModel model, Func action, string queue) { try { model.QueueDeclareNoWait(queue, false, true, false, null); - action(model, queue); + await action(model, queue); } finally { @@ -272,26 +261,26 @@ internal void EnsureNotEmpty(string q, string body) WithTemporaryModel(x => x.BasicPublish("", q, _encoding.GetBytes(body))); } - internal void WithNonEmptyQueue(Action action) + internal Task WithNonEmptyQueueAsync(Func action) { - WithNonEmptyQueue(action, "msg"); + return WithNonEmptyQueueAsync(action, "msg"); } - internal void WithNonEmptyQueue(Action action, string msg) + internal Task WithNonEmptyQueueAsync(Func action, string msg) { - WithTemporaryNonExclusiveQueue((m, q) => + return WithTemporaryNonExclusiveQueueAsync(async (m, q) => { EnsureNotEmpty(q, msg); - action(m, q); + await action(m, q); }); } - internal void WithEmptyQueue(Action action) + internal Task WithEmptyQueueAsync(Func action) { - WithTemporaryNonExclusiveQueue((model, queue) => + return WithTemporaryNonExclusiveQueueAsync(async (model, queue) => { model.QueuePurge(queue); - action(model, queue); + await action(model, queue); }); } @@ -338,30 +327,18 @@ internal bool InitiatedByPeerOrLibrary(ShutdownEventArgs evt) return !(evt.Initiator == ShutdownInitiator.Application); } - // - // Concurrency - // - - internal void WaitOn(object o) - { - lock (o) - { - Monitor.Wait(o, TimingFixture.TestTimeout); - } - } - // // Flow Control // - internal void Block() + internal Task BlockAsync() { - RabbitMQCtl.Block(_conn, _encoding); + return RabbitMQCtl.BlockAsync(_conn, _encoding, _output); } internal void Unblock() { - RabbitMQCtl.Unblock(); + RabbitMQCtl.Unblock(_output); } // @@ -370,28 +347,23 @@ internal void Unblock() internal void CloseConnection(IConnection conn) { - RabbitMQCtl.CloseConnection(conn); - } - - internal void CloseAllConnections() - { - RabbitMQCtl.CloseAllConnections(); + RabbitMQCtl.CloseConnection(conn, _output); } - internal void RestartRabbitMQ() + internal Task RestartRabbitMQAsync() { - RabbitMQCtl.RestartRabbitMQ(); + return RabbitMQCtl.RestartRabbitMQAsync(_output); } internal void StopRabbitMQ() { - RabbitMQCtl.StopRabbitMQ(); + RabbitMQCtl.StopRabbitMQ(_output); } internal void StartRabbitMQ() { - RabbitMQCtl.StartRabbitMQ(); - RabbitMQCtl.AwaitRabbitMQ(); + RabbitMQCtl.StartRabbitMQ(_output); + RabbitMQCtl.AwaitRabbitMQ(_output); } // @@ -400,7 +372,7 @@ internal void StartRabbitMQ() internal void Wait(ManualResetEventSlim latch, TimeSpan? timeSpan = null) { - Assert.True(latch.Wait(timeSpan ?? TimeSpan.FromSeconds(15)), "waiting on a latch timed out"); + Assert.True(latch.Wait(timeSpan ?? TimeSpan.FromSeconds(30)), "waiting on a latch timed out"); } // diff --git a/projects/Unit/Helper/DebugUtil.cs b/projects/Unit/Helper/DebugUtil.cs index 9ee8ed76bd..ae5b303629 100644 --- a/projects/Unit/Helper/DebugUtil.cs +++ b/projects/Unit/Helper/DebugUtil.cs @@ -33,6 +33,9 @@ using System.Collections; using System.IO; using System.Reflection; +using System.Text; + +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { @@ -42,44 +45,40 @@ namespace RabbitMQ.Client.Unit /// internal static class DebugUtil { - ///Print a hex dump of the supplied bytes to stdout. - public static void Dump(byte[] bytes) - { - Dump(bytes, Console.Out); - } - ///Print a hex dump of the supplied bytes to the supplied TextWriter. - public static void Dump(byte[] bytes, TextWriter writer) + public static void Dump(byte[] bytes, ITestOutputHelper writer) { int rowlen = 16; for (int count = 0; count < bytes.Length; count += rowlen) { int thisRow = Math.Min(bytes.Length - count, rowlen); + StringBuilder builder = new StringBuilder(); - writer.Write("{0:X8}: ", count); + builder.AppendFormat("{0:X8}: ", count); for (int i = 0; i < thisRow; i++) { - writer.Write("{0:X2}", bytes[count + i]); + builder.AppendFormat("{0:X2}", bytes[count + i]); } for (int i = 0; i < (rowlen - thisRow); i++) { - writer.Write(" "); + builder.Append(" "); } - writer.Write(" "); + builder.Append(" "); for (int i = 0; i < thisRow; i++) { if (bytes[count + i] >= 32 && bytes[count + i] < 128) { - writer.Write((char)bytes[count + i]); + builder.Append((char)bytes[count + i]); } else { - writer.Write('.'); + builder.Append('.'); } } - writer.WriteLine(); + + writer.WriteLine(builder.ToString()); } if (bytes.Length % 16 != 0) { @@ -89,15 +88,15 @@ public static void Dump(byte[] bytes, TextWriter writer) ///Prints an indented key/value pair; used by DumpProperties() ///Recurses into the value using DumpProperties(). - public static void DumpKeyValue(string key, object value, TextWriter writer, int indent) + public static void DumpKeyValue(string key, object value, ITestOutputHelper writer, int indent) { string prefix = $"{new string(' ', indent + 2)}{key}: "; - writer.Write(prefix); + writer.WriteLine(prefix); DumpProperties(value, writer, indent + 2); } ///Dump properties of objects to the supplied writer. - public static void DumpProperties(object value, TextWriter writer, int indent) + public static void DumpProperties(object value, ITestOutputHelper writer, int indent) { switch (value) { @@ -112,7 +111,7 @@ public static void DumpProperties(object value, TextWriter writer, int indent) Dump(byteVal, writer); break; case ValueType _: - writer.WriteLine(value); + writer.WriteLine($"{value}"); break; case IDictionary dictionary: { diff --git a/projects/Unit/RabbitMQCtl.cs b/projects/Unit/RabbitMQCtl.cs index 29242a6e31..ab3c4b41d6 100644 --- a/projects/Unit/RabbitMQCtl.cs +++ b/projects/Unit/RabbitMQCtl.cs @@ -37,6 +37,9 @@ using System.Text; using System.Text.RegularExpressions; using System.Threading; +using System.Threading.Tasks; + +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { @@ -93,8 +96,9 @@ private static Process GetRabbitMqCtlInvokeAction(string args) // // Shelling Out // - private static string ExecRabbitMQCtl(string args) + private static string ExecRabbitMQCtl(string args, ITestOutputHelper outputHelper) { + Stopwatch timer = Stopwatch.StartNew(); try { using var process = GetRabbitMqCtlInvokeAction(args); @@ -108,15 +112,22 @@ private static string ExecRabbitMQCtl(string args) ReportExecFailure("rabbitmqctl", args, $"{stderr}\n{stdout}"); } + outputHelper?.WriteLine($"Successfully executed RabbitMQCtl {args} in {timer.ElapsedMilliseconds:N0} ms."); return stdout; } catch (Exception e) { + outputHelper?.WriteLine($"Failed to executed RabbitMQCtl {args} in {timer.ElapsedMilliseconds:N0} ms."); ReportExecFailure("rabbitmqctl", args, e.Message); throw; } } + private static void Process_Exited(object? sender, EventArgs e) + { + throw new NotImplementedException(); + } + private static Process CreateProcess(string cmd, string arguments, string? workDirectory = null) { return new Process @@ -151,11 +162,11 @@ private static bool IsRunningOnMonoOrDotNetCore() // // Flow Control // - public static void Block(IConnection conn, Encoding encoding) + public static async Task BlockAsync(IConnection conn, Encoding encoding, ITestOutputHelper outputHelper) { - ExecRabbitMQCtl("set_vm_memory_high_watermark 0.000000001"); + ExecRabbitMQCtl("set_vm_memory_high_watermark 0.000000001", outputHelper); // give rabbitmqctl some time to do its job - Thread.Sleep(1200); + await Task.Delay(1200); Publish(conn, encoding); } @@ -165,20 +176,20 @@ public static void Publish(IConnection conn, Encoding encoding) ch.BasicPublish("amq.fanout", "", encoding.GetBytes("message")); } - public static void Unblock() + public static void Unblock(ITestOutputHelper outputHelper) { - ExecRabbitMQCtl("set_vm_memory_high_watermark 0.4"); + ExecRabbitMQCtl("set_vm_memory_high_watermark 0.4", outputHelper); } - public static void CloseConnection(IConnection conn) + public static void CloseConnection(IConnection conn, ITestOutputHelper outputHelper) { - CloseConnection(GetConnectionPid(conn.ClientProvidedName)); + CloseConnection(GetConnectionPid(conn.ClientProvidedName, outputHelper), outputHelper); } private static readonly Regex s_getConnectionProperties = new Regex(@"^(?<[^>]*>)\s\[.*""connection_name"",""(?[^""]*)"".*\]$", RegexOptions.Multiline | RegexOptions.Compiled); - private static string GetConnectionPid(string connectionName) + private static string GetConnectionPid(string connectionName, ITestOutputHelper outputHelper) { - string stdout = ExecRabbitMQCtl("list_connections --silent pid client_properties"); + string stdout = ExecRabbitMQCtl("list_connections --silent pid client_properties", outputHelper); var match = s_getConnectionProperties.Match(stdout); while (match.Success) @@ -194,46 +205,46 @@ private static string GetConnectionPid(string connectionName) throw new Exception($"No connection found with name: {connectionName}"); } - private static void CloseConnection(string pid) + private static void CloseConnection(string pid, ITestOutputHelper outputHelper) { - ExecRabbitMQCtl($"close_connection \"{pid}\" \"Closed via rabbitmqctl\""); + ExecRabbitMQCtl($"close_connection \"{pid}\" \"Closed via rabbitmqctl\"", outputHelper); } - public static void CloseAllConnections() + public static void CloseAllConnections(ITestOutputHelper outputHelper) { - foreach (var pid in EnumerateConnectionsPid()) + foreach (var pid in EnumerateConnectionsPid(outputHelper)) { - CloseConnection(pid); + CloseConnection(pid, outputHelper); } } - private static string[] EnumerateConnectionsPid() + private static string[] EnumerateConnectionsPid(ITestOutputHelper outputHelper) { - string rabbitmqCtlResult = ExecRabbitMQCtl("list_connections --silent pid"); + string rabbitmqCtlResult = ExecRabbitMQCtl("list_connections --silent pid", outputHelper); return rabbitmqCtlResult.Split(newLine, StringSplitOptions.RemoveEmptyEntries); } - public static void RestartRabbitMQ() + public static async Task RestartRabbitMQAsync(ITestOutputHelper outputHelper) { - StopRabbitMQ(); - Thread.Sleep(500); - StartRabbitMQ(); - AwaitRabbitMQ(); + StopRabbitMQ(outputHelper); + await Task.Delay(500); + StartRabbitMQ(outputHelper); + AwaitRabbitMQ(outputHelper); } - public static void StopRabbitMQ() + public static void StopRabbitMQ(ITestOutputHelper outputHelper) { - ExecRabbitMQCtl("stop_app"); + ExecRabbitMQCtl("stop_app", outputHelper); } - public static void StartRabbitMQ() + public static void StartRabbitMQ(ITestOutputHelper outputHelper) { - ExecRabbitMQCtl("start_app"); + ExecRabbitMQCtl("start_app", outputHelper); } - public static void AwaitRabbitMQ() + public static void AwaitRabbitMQ(ITestOutputHelper outputHelper) { - ExecRabbitMQCtl("await_startup"); + ExecRabbitMQCtl("await_startup", outputHelper); } } } diff --git a/projects/Unit/TestAsyncConsumer.cs b/projects/Unit/TestAsyncConsumer.cs index 08032109a3..0f6799bb70 100644 --- a/projects/Unit/TestAsyncConsumer.cs +++ b/projects/Unit/TestAsyncConsumer.cs @@ -65,10 +65,10 @@ public void TestBasicRoundtrip() m.BasicPublish("", q.QueueName, body); var consumer = new AsyncEventingBasicConsumer(m); var are = new AutoResetEvent(false); - consumer.Received += async (o, a) => + consumer.Received += (o, a) => { are.Set(); - await Task.Yield(); + return Task.CompletedTask; }; string tag = m.BasicConsume(q.QueueName, true, consumer); // ensure we get a delivery diff --git a/projects/Unit/TestBasicGet.cs b/projects/Unit/TestBasicGet.cs index 9bf653e6fa..c76de0c4aa 100644 --- a/projects/Unit/TestBasicGet.cs +++ b/projects/Unit/TestBasicGet.cs @@ -29,6 +29,8 @@ // Copyright (c) 2007-2020 VMware, Inc. All rights reserved. //--------------------------------------------------------------------------- +using System.Threading.Tasks; + using RabbitMQ.Client.Exceptions; using Xunit; @@ -43,36 +45,39 @@ public TestBasicGet(ITestOutputHelper output) : base(output) } [Fact] - public void TestBasicGetWithClosedChannel() + public async Task TestBasicGetWithClosedChannel() { - WithNonEmptyQueue((_, q) => + await WithNonEmptyQueueAsync((_, q) => { WithClosedModel(cm => { Assert.Throws(() => cm.BasicGet(q, true)); }); + return Task.CompletedTask; }); } [Fact] - public void TestBasicGetWithEmptyResponse() + public async Task TestBasicGetWithEmptyResponse() { - WithEmptyQueue((model, queue) => + await WithEmptyQueueAsync((model, queue) => { BasicGetResult res = model.BasicGet(queue, false); Assert.Null(res); + return Task.CompletedTask; }); } [Fact] - public void TestBasicGetWithNonEmptyResponseAndAutoAckMode() + public async Task TestBasicGetWithNonEmptyResponseAndAutoAckMode() { const string msg = "for basic.get"; - WithNonEmptyQueue((model, queue) => + await WithNonEmptyQueueAsync((model, queue) => { BasicGetResult res = model.BasicGet(queue, true); Assert.Equal(msg, _encoding.GetString(res.Body.ToArray())); AssertMessageCount(queue, 0); + return Task.CompletedTask; }, msg); } } diff --git a/projects/Unit/TestConnectionBlocked.cs b/projects/Unit/TestConnectionBlocked.cs index c6bda72b51..5f214e7252 100644 --- a/projects/Unit/TestConnectionBlocked.cs +++ b/projects/Unit/TestConnectionBlocked.cs @@ -42,68 +42,54 @@ namespace RabbitMQ.Client.Unit public class TestConnectionBlocked : IntegrationFixture { private readonly ManualResetEventSlim _connDisposed = new ManualResetEventSlim(false); - private readonly object _lockObject = new object(); - private bool _notified; + private readonly ManualResetEventSlim _connectionUnblocked = new ManualResetEventSlim(false); public TestConnectionBlocked(ITestOutputHelper output) : base(output) { } [Fact] - public void TestConnectionBlockedNotification() + public async Task TestConnectionBlockedNotification() { - _notified = false; _conn.ConnectionBlocked += HandleBlocked; _conn.ConnectionUnblocked += HandleUnblocked; - try - { - Block(); - - lock (_lockObject) - { - if (!_notified) - { - Monitor.Wait(_lockObject, TimeSpan.FromSeconds(15)); - } - } - - if (!_notified) - { - Assert.True(false, "Unblock notification not received."); - } - } - finally - { - Unblock(); - } + await BlockAsync(); + Assert.True(_connectionUnblocked.Wait(TimeSpan.FromSeconds(15))); + ResetEvent(); } [Fact] - public void TestDisposeOnBlockedConnectionDoesNotHang() + public async Task TestDisposeOnBlockedConnectionDoesNotHang() { - _notified = false; - try { - Block(); - - Task.Factory.StartNew(DisposeConnection); - - if (!_connDisposed.Wait(TimeSpan.FromSeconds(20))) + await BlockAsync(); + using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20))) { - Assert.True(false, "Dispose must have finished within 20 seconds after starting"); + await Task.Run(DisposeConnection, cts.Token); } } + catch (TaskCanceledException) + { + Assert.True(false, "Dispose must have finished within 20 seconds after starting"); + } finally { - Unblock(); + ResetEvent(); } } - protected override void ReleaseResources() + private void ResetEvent() { - Unblock(); + if (!_connectionUnblocked.IsSet) + { + Unblock(); + } + else + { + _connectionUnblocked.Reset(); + } } private void HandleBlocked(object sender, ConnectionBlockedEventArgs args) @@ -113,11 +99,7 @@ private void HandleBlocked(object sender, ConnectionBlockedEventArgs args) private void HandleUnblocked(object sender, EventArgs ea) { - lock (_lockObject) - { - _notified = true; - Monitor.PulseAll(_lockObject); - } + _connectionUnblocked.Set(); } private void DisposeConnection() diff --git a/projects/Unit/TestConnectionRecovery.cs b/projects/Unit/TestConnectionRecovery.cs index eef7b84e33..1240d5af6c 100644 --- a/projects/Unit/TestConnectionRecovery.cs +++ b/projects/Unit/TestConnectionRecovery.cs @@ -31,6 +31,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.Events; @@ -80,7 +81,7 @@ protected override void ReleaseResources() _conn.Close(); } - Unblock(); + //Unblock(); } [Fact] @@ -150,13 +151,13 @@ public void TestBasicRejectAfterChannelRecovery() } [Fact] - public void TestBasicAckAfterBasicGetAndChannelRecovery() + public async Task TestBasicAckAfterBasicGetAndChannelRecovery() { string q = GenerateQueueName(); _model.QueueDeclare(q, false, false, false, null); // create an offset _model.BasicPublish("", q, _messageBody); - Thread.Sleep(50); + await Task.Delay(50); BasicGetResult g = _model.BasicGet(q, false); CloseAndWaitForRecovery(); Assert.True(_conn.IsOpen); @@ -168,7 +169,7 @@ public void TestBasicAckAfterBasicGetAndChannelRecovery() } [Fact] - public void TestBasicAckEventHandlerRecovery() + public async Task TestBasicAckEventHandlerRecovery() { _model.ConfirmSelect(); var latch = new ManualResetEventSlim(false); @@ -179,7 +180,11 @@ public void TestBasicAckEventHandlerRecovery() CloseAndWaitForRecovery(); Assert.True(_model.IsOpen); - WithTemporaryNonExclusiveQueue(_model, (m, q) => m.BasicPublish("", q, _messageBody)); + await WithTemporaryNonExclusiveQueueAsync(_model, (m, q) => + { + m.BasicPublish("", q, _messageBody); + return Task.CompletedTask; + }); Wait(latch); } @@ -230,7 +235,7 @@ public void TestBasicConnectionRecoveryWithEndpointList() } [Fact] - public void TestBasicConnectionRecoveryStopsAfterManualClose() + public async Task TestBasicConnectionRecoveryStopsAfterManualClose() { Assert.True(_conn.IsOpen); AutorecoveringConnection c = CreateAutorecoveringConnection(); @@ -243,9 +248,9 @@ public void TestBasicConnectionRecoveryStopsAfterManualClose() latch.WaitOne(30000); // we got the failed reconnection event. bool triedRecoveryAfterClose = false; c.Close(); - Thread.Sleep(5000); + await Task.Delay(5000); c.ConnectionRecoveryError += (o, args) => triedRecoveryAfterClose = true; - Thread.Sleep(10000); + await Task.Delay(10000); Assert.False(triedRecoveryAfterClose); } finally @@ -272,10 +277,10 @@ public void TestBasicConnectionRecoveryWithEndpointListAndUnreachableHosts() } [Fact] - public void TestBasicConnectionRecoveryOnBrokerRestart() + public async Task TestBasicConnectionRecoveryOnBrokerRestart() { Assert.True(_conn.IsOpen); - RestartServerAndWaitForRecovery(); + await RestartServerAndWaitForRecoveryAsync(); Assert.True(_conn.IsOpen); } @@ -288,57 +293,59 @@ public void TestBasicModelRecovery() } [Fact] - public void TestBasicModelRecoveryOnServerRestart() + public async Task TestBasicModelRecoveryOnServerRestart() { Assert.True(_model.IsOpen); - RestartServerAndWaitForRecovery(); + await RestartServerAndWaitForRecoveryAsync(); Assert.True(_model.IsOpen); } [Fact] - public void TestBlockedListenersRecovery() + public async Task TestBlockedListenersRecovery() { var latch = new ManualResetEventSlim(false); _conn.ConnectionBlocked += (c, reason) => latch.Set(); CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); - Block(); + await BlockAsync(); Wait(latch); Unblock(); } [Fact] - public void TestClientNamedQueueRecovery() + public async Task TestClientNamedQueueRecovery() { string s = "dotnet-client.test.recovery.q1"; - WithTemporaryNonExclusiveQueue(_model, (m, q) => + await WithTemporaryNonExclusiveQueueAsync(_model, (m, q) => { CloseAndWaitForRecovery(); AssertQueueRecovery(m, q, false); _model.QueueDelete(q); + return Task.CompletedTask; }, s); } [Fact] - public void TestClientNamedQueueRecoveryNoWait() + public async Task TestClientNamedQueueRecoveryNoWait() { string s = "dotnet-client.test.recovery.q1-nowait"; - WithTemporaryQueueNoWait(_model, (m, q) => + await WithTemporaryQueueNoWaitAsync(_model, (m, q) => { CloseAndWaitForRecovery(); AssertQueueRecovery(m, q); + return Task.CompletedTask; }, s); } [Fact] - public void TestClientNamedQueueRecoveryOnServerRestart() + public async Task TestClientNamedQueueRecoveryOnServerRestart() { string s = "dotnet-client.test.recovery.q1"; - WithTemporaryNonExclusiveQueue(_model, (m, q) => + await WithTemporaryNonExclusiveQueueAsync(_model, async (m, q) => { - RestartServerAndWaitForRecovery(); + await RestartServerAndWaitForRecoveryAsync(); AssertQueueRecovery(m, q, false); _model.QueueDelete(q); }, s); @@ -537,22 +544,20 @@ public void TestDeclarationOfManyAutoDeleteQueuesWithTransientConsumer() } [Fact] - public void TestExchangeRecovery() + public async Task TestExchangeRecovery() { - string x = "dotnet-client.test.recovery.x1"; - DeclareNonDurableExchange(_model, x); + string x = DeclareNonDurableExchange(_model); CloseAndWaitForRecovery(); - AssertExchangeRecovery(_model, x); + await AssertExchangeRecoveryAsync(_model, x); _model.ExchangeDelete(x); } [Fact] - public void TestExchangeRecoveryWithNoWait() + public async Task TestExchangeRecoveryWithNoWait() { - string x = "dotnet-client.test.recovery.x1-nowait"; - DeclareNonDurableExchangeNoWait(_model, x); + string x = DeclareNonDurableExchangeNoWait(_model); CloseAndWaitForRecovery(); - AssertExchangeRecovery(_model, x); + await AssertExchangeRecoveryAsync(_model, x); _model.ExchangeDelete(x); } @@ -604,7 +609,7 @@ public void TestQueueRecoveryWithManyQueues() // rabbitmq/rabbitmq-dotnet-client#43 [Fact] - public void TestClientNamedTransientAutoDeleteQueueAndBindingRecovery() + public async Task TestClientNamedTransientAutoDeleteQueueAndBindingRecovery() { string q = Guid.NewGuid().ToString(); string x = "tmp-fanout"; @@ -614,7 +619,7 @@ public void TestClientNamedTransientAutoDeleteQueueAndBindingRecovery() ch.ExchangeDeclare(exchange: x, type: "fanout"); ch.QueueDeclare(queue: q, durable: false, exclusive: false, autoDelete: true, arguments: null); ch.QueueBind(queue: q, exchange: x, routingKey: ""); - RestartServerAndWaitForRecovery(); + await RestartServerAndWaitForRecoveryAsync(); Assert.True(ch.IsOpen); ch.ConfirmSelect(); ch.QueuePurge(q); @@ -629,7 +634,7 @@ public void TestClientNamedTransientAutoDeleteQueueAndBindingRecovery() // rabbitmq/rabbitmq-dotnet-client#43 [Fact] - public void TestServerNamedTransientAutoDeleteQueueAndBindingRecovery() + public async Task TestServerNamedTransientAutoDeleteQueueAndBindingRecovery() { string x = "tmp-fanout"; IModel ch = _conn.CreateModel(); @@ -646,7 +651,7 @@ public void TestServerNamedTransientAutoDeleteQueueAndBindingRecovery() latch.Set(); }; ch.QueueBind(queue: nameBefore, exchange: x, routingKey: ""); - RestartServerAndWaitForRecovery(); + await RestartServerAndWaitForRecoveryAsync(); Wait(latch); Assert.True(ch.IsOpen); Assert.NotEqual(nameBefore, nameAfter); @@ -665,11 +670,9 @@ public void TestRecoveryEventHandlersOnChannel() { int counter = 0; ((AutorecoveringModel)_model).Recovery += (source, ea) => Interlocked.Increment(ref counter); - CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); Assert.True(_conn.IsOpen); - Assert.True(counter >= 1); } @@ -678,13 +681,11 @@ public void TestRecoveryEventHandlersOnConnection() { int counter = 0; ((AutorecoveringConnection)_conn).RecoverySucceeded += (source, ea) => Interlocked.Increment(ref counter); - CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); Assert.True(_conn.IsOpen); - Assert.True(counter >= 3); } @@ -693,13 +694,11 @@ public void TestRecoveryEventHandlersOnModel() { int counter = 0; ((AutorecoveringModel)_model).Recovery += (source, ea) => Interlocked.Increment(ref counter); - CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); Assert.True(_model.IsOpen); - Assert.True(counter >= 3); } @@ -774,7 +773,7 @@ public void TestShutdownEventHandlersRecoveryOnConnection() } [Fact] - public void TestShutdownEventHandlersRecoveryOnConnectionAfterDelayedServerRestart() + public async Task TestShutdownEventHandlersRecoveryOnConnectionAfterDelayedServerRestart() { int counter = 0; _conn.ConnectionShutdown += (c, args) => Interlocked.Increment(ref counter); @@ -786,8 +785,8 @@ public void TestShutdownEventHandlersRecoveryOnConnectionAfterDelayedServerResta try { StopRabbitMQ(); - Console.WriteLine("Stopped RabbitMQ. About to sleep for multiple recovery intervals..."); - Thread.Sleep(7000); + _output.WriteLine("Stopped RabbitMQ. About to sleep for multiple recovery intervals..."); + await Task.Delay(7000); } finally { @@ -1009,22 +1008,22 @@ public void TestThatDeletedQueuesDontReappearOnRecovery() } [Fact] - public void TestUnblockedListenersRecovery() + public async Task TestUnblockedListenersRecovery() { var latch = new ManualResetEventSlim(false); _conn.ConnectionUnblocked += (source, ea) => latch.Set(); CloseAndWaitForRecovery(); CloseAndWaitForRecovery(); - Block(); + await BlockAsync(); Unblock(); Wait(latch); } - internal void AssertExchangeRecovery(IModel m, string x) + internal Task AssertExchangeRecoveryAsync(IModel m, string x) { m.ConfirmSelect(); - WithTemporaryNonExclusiveQueue(m, (_, q) => + return WithTemporaryNonExclusiveQueueAsync(m, (_, q) => { string rk = "routing-key"; m.QueueBind(q, x, rk); @@ -1032,15 +1031,11 @@ internal void AssertExchangeRecovery(IModel m, string x) Assert.True(WaitForConfirms(m)); m.ExchangeDeclarePassive(x); + return Task.CompletedTask; }); } - internal void AssertQueueRecovery(IModel m, string q) - { - AssertQueueRecovery(m, q, true); - } - - internal void AssertQueueRecovery(IModel m, string q, bool exclusive) + internal void AssertQueueRecovery(IModel m, string q, bool exclusive = true) { m.ConfirmSelect(); m.QueueDeclarePassive(q); @@ -1069,18 +1064,13 @@ internal void CloseAndWaitForRecovery() internal void CloseAndWaitForRecovery(AutorecoveringConnection conn) { + Stopwatch timer = Stopwatch.StartNew(); ManualResetEventSlim sl = PrepareForShutdown(conn); ManualResetEventSlim rl = PrepareForRecovery(conn); CloseConnection(conn); Wait(sl); Wait(rl); - } - - internal void CloseAndWaitForShutdown(AutorecoveringConnection conn) - { - ManualResetEventSlim sl = PrepareForShutdown(conn); - CloseConnection(conn); - Wait(sl); + _output.WriteLine($"Shutdown and recovered RabbitMQ in {timer.ElapsedMilliseconds}ms"); } internal ManualResetEventSlim PrepareForRecovery(IConnection conn) @@ -1103,35 +1093,16 @@ internal static ManualResetEventSlim PrepareForShutdown(IConnection conn) return latch; } - internal void RestartServerAndWaitForRecovery() - { - RestartServerAndWaitForRecovery((AutorecoveringConnection)_conn); - } - - internal void RestartServerAndWaitForRecovery(AutorecoveringConnection conn) + internal async Task RestartServerAndWaitForRecoveryAsync() { + AutorecoveringConnection conn = (AutorecoveringConnection)_conn; ManualResetEventSlim sl = PrepareForShutdown(conn); ManualResetEventSlim rl = PrepareForRecovery(conn); - RestartRabbitMQ(); + await RestartRabbitMQAsync(); Wait(sl); Wait(rl); } - internal void WaitForRecovery() - { - Wait(PrepareForRecovery((AutorecoveringConnection)_conn)); - } - - internal void WaitForRecovery(AutorecoveringConnection conn) - { - Wait(PrepareForRecovery(conn)); - } - - internal void WaitForShutdown() - { - Wait(PrepareForShutdown(_conn)); - } - internal void WaitForShutdown(IConnection conn) { Wait(PrepareForShutdown(conn)); diff --git a/projects/Unit/TestConsumerCancelNotify.cs b/projects/Unit/TestConsumerCancelNotify.cs index 75dbfddcf8..4a1362c765 100644 --- a/projects/Unit/TestConsumerCancelNotify.cs +++ b/projects/Unit/TestConsumerCancelNotify.cs @@ -66,6 +66,7 @@ public void TestConsumerCancelEvent() [Fact] public void TestCorrectConsumerTag() { + ManualResetEventSlim manualResetEventSlim = new ManualResetEventSlim(); string q1 = GenerateQueueName(); string q2 = GenerateQueueName(); @@ -79,15 +80,12 @@ public void TestCorrectConsumerTag() string notifiedConsumerTag = null; consumer.ConsumerCancelled += (sender, args) => { - lock (lockObject) - { notifiedConsumerTag = args.ConsumerTags.First(); - Monitor.PulseAll(lockObject); - } + manualResetEventSlim.Set(); }; _model.QueueDelete(q1); - WaitOn(lockObject); + Assert.True(manualResetEventSlim.Wait(TimingFixture.TestTimeout)); Assert.Equal(consumerTag1, notifiedConsumerTag); _model.QueueDelete(q2); @@ -95,12 +93,12 @@ public void TestCorrectConsumerTag() private void TestConsumerCancel(string queue, bool EventMode, ref bool notified) { + ManualResetEventSlim manualResetEventSlim = new ManualResetEventSlim(); _model.QueueDeclare(queue, false, true, false, null); - IBasicConsumer consumer = new CancelNotificationConsumer(_model, this, EventMode); + IBasicConsumer consumer = new CancelNotificationConsumer(_model, this, EventMode, manualResetEventSlim); string actualConsumerTag = _model.BasicConsume(queue, false, consumer); - _model.QueueDelete(queue); - WaitOn(lockObject); + Assert.True(manualResetEventSlim.Wait(TimingFixture.TestTimeout)); Assert.True(notified); Assert.Equal(actualConsumerTag, consumerTag); } @@ -109,12 +107,14 @@ private class CancelNotificationConsumer : DefaultBasicConsumer { private readonly TestConsumerCancelNotify _testClass; private readonly bool _eventMode; + private readonly ManualResetEventSlim _manualResetEventSlim; - public CancelNotificationConsumer(IModel model, TestConsumerCancelNotify tc, bool EventMode) + public CancelNotificationConsumer(IModel model, TestConsumerCancelNotify tc, bool EventMode, ManualResetEventSlim manualResetEventSlim) : base(model) { _testClass = tc; _eventMode = EventMode; + _manualResetEventSlim = manualResetEventSlim; if (EventMode) { ConsumerCancelled += Cancelled; @@ -125,24 +125,19 @@ public override void HandleBasicCancel(string consumerTag) { if (!_eventMode) { - lock (_testClass.lockObject) - { - _testClass.notifiedCallback = true; - _testClass.consumerTag = consumerTag; - Monitor.PulseAll(_testClass.lockObject); - } + _testClass.notifiedCallback = true; + _testClass.consumerTag = consumerTag; + _manualResetEventSlim.Set(); } + base.HandleBasicCancel(consumerTag); } private void Cancelled(object sender, ConsumerEventArgs arg) { - lock (_testClass.lockObject) - { _testClass.notifiedEvent = true; _testClass.consumerTag = arg.ConsumerTags[0]; - Monitor.PulseAll(_testClass.lockObject); - } + _manualResetEventSlim.Set(); } } } diff --git a/projects/Unit/TestConsumerExceptions.cs b/projects/Unit/TestConsumerExceptions.cs index 63b91a8408..c60dc5cc10 100644 --- a/projects/Unit/TestConsumerExceptions.cs +++ b/projects/Unit/TestConsumerExceptions.cs @@ -109,22 +109,17 @@ public override void HandleBasicCancelOk(string consumerTag) protected void TestExceptionHandlingWith(IBasicConsumer consumer, Action action) { - object o = new object(); - bool notified = false; + ManualResetEventSlim manualResetEventSlim = new ManualResetEventSlim(); string q = _model.QueueDeclare(); - _model.CallbackException += (m, evt) => { - notified = true; - Monitor.PulseAll(o); + manualResetEventSlim.Set(); }; string tag = _model.BasicConsume(q, true, consumer); action(_model, q, consumer, tag); - WaitOn(o); - - Assert.True(notified); + Assert.True(manualResetEventSlim.Wait(TimingFixture.TestTimeout)); } public TestConsumerExceptions(ITestOutputHelper output) : base(output) diff --git a/projects/Unit/TestConsumerOperationDispatch.cs b/projects/Unit/TestConsumerOperationDispatch.cs index b48e9615a8..d18e26298d 100644 --- a/projects/Unit/TestConsumerOperationDispatch.cs +++ b/projects/Unit/TestConsumerOperationDispatch.cs @@ -185,7 +185,7 @@ public ShutdownLatchConsumer(ManualResetEventSlim latch, ManualResetEventSlim du public override void HandleModelShutdown(object model, ShutdownEventArgs reason) { // keep track of duplicates - if (Latch.Wait(0)) + if (Latch.IsSet) { DuplicateLatch.Set(); } @@ -207,7 +207,7 @@ public void TestModelShutdownHandler() _model.BasicConsume(queue: q, autoAck: true, consumer: c); _model.Close(); Wait(latch, TimeSpan.FromSeconds(5)); - Assert.False(duplicateLatch.Wait(TimeSpan.FromSeconds(5)), + Assert.False(duplicateLatch.Wait(TimeSpan.FromSeconds(1)), "event handler fired more than once"); } } diff --git a/projects/Unit/TestContentHeaderCodec.cs b/projects/Unit/TestContentHeaderCodec.cs index 5700f550a0..a3cfcdf6c1 100644 --- a/projects/Unit/TestContentHeaderCodec.cs +++ b/projects/Unit/TestContentHeaderCodec.cs @@ -33,12 +33,20 @@ using System.Collections.Generic; using Xunit; +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { public class TestContentHeaderCodec { + protected readonly ITestOutputHelper _output; + + public TestContentHeaderCodec(ITestOutputHelper output) + { + _output = output; + } + private void Check(ReadOnlyMemory actual, ReadOnlyMemory expected) { try @@ -47,12 +55,12 @@ private void Check(ReadOnlyMemory actual, ReadOnlyMemory expected) } catch { - Console.WriteLine(); - Console.WriteLine("EXPECTED =================================================="); - DebugUtil.Dump(expected.ToArray(), Console.Out); - Console.WriteLine("ACTUAL ===================================================="); - DebugUtil.Dump(actual.ToArray(), Console.Out); - Console.WriteLine("==========================================================="); + _output.WriteLine(""); + _output.WriteLine("EXPECTED =================================================="); + DebugUtil.Dump(expected.ToArray(), _output); + _output.WriteLine("ACTUAL ===================================================="); + DebugUtil.Dump(actual.ToArray(), _output); + _output.WriteLine("==========================================================="); throw; } } diff --git a/projects/Unit/TestEventingConsumer.cs b/projects/Unit/TestEventingConsumer.cs index 601421225f..95b3b2fe09 100644 --- a/projects/Unit/TestEventingConsumer.cs +++ b/projects/Unit/TestEventingConsumer.cs @@ -85,7 +85,7 @@ public void TestEventingConsumerRegistrationEvents() public void TestEventingConsumerDeliveryEvents() { string q = _model.QueueDeclare(); - object o = new object(); + ManualResetEventSlim manualResetEventSlim = new ManualResetEventSlim(false); bool receivedInvoked = false; object receivedSender = null; @@ -95,14 +95,14 @@ public void TestEventingConsumerDeliveryEvents() { receivedInvoked = true; receivedSender = s; - - Monitor.PulseAll(o); + manualResetEventSlim.Set(); }; _model.BasicConsume(q, true, ec); _model.BasicPublish("", q, _encoding.GetBytes("msg")); - WaitOn(o); + Assert.True(manualResetEventSlim.Wait(TimingFixture.TestTimeout)); + manualResetEventSlim.Reset(); Assert.True(receivedInvoked); Assert.NotNull(receivedSender); Assert.Equal(ec, receivedSender); @@ -115,12 +115,11 @@ public void TestEventingConsumerDeliveryEvents() { shutdownInvoked = true; shutdownSender = s; - - Monitor.PulseAll(o); + manualResetEventSlim.Set(); }; _model.Close(); - WaitOn(o); + Assert.True(manualResetEventSlim.Wait(TimingFixture.TestTimeout)); Assert.True(shutdownInvoked); Assert.NotNull(shutdownSender); diff --git a/projects/Unit/TestFieldTableFormatting.cs b/projects/Unit/TestFieldTableFormatting.cs index dd4ecd6016..8d2cf75eb5 100644 --- a/projects/Unit/TestFieldTableFormatting.cs +++ b/projects/Unit/TestFieldTableFormatting.cs @@ -36,11 +36,16 @@ using RabbitMQ.Client.Impl; using Xunit; +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { public class TestFieldTableFormatting : WireFormattingFixture { + public TestFieldTableFormatting(ITestOutputHelper output) : base(output) + { + } + [Fact] public void TestStandardTypes() { diff --git a/projects/Unit/TestFieldTableFormattingGeneric.cs b/projects/Unit/TestFieldTableFormattingGeneric.cs index 640e68e03f..553179cf6c 100644 --- a/projects/Unit/TestFieldTableFormattingGeneric.cs +++ b/projects/Unit/TestFieldTableFormattingGeneric.cs @@ -37,12 +37,17 @@ using RabbitMQ.Client.Impl; using Xunit; +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { public class TestFieldTableFormattingGeneric : WireFormattingFixture { + public TestFieldTableFormattingGeneric(ITestOutputHelper output) : base(output) + { + } + [Fact] public void TestStandardTypes() { diff --git a/projects/Unit/TestFrameFormatting.cs b/projects/Unit/TestFrameFormatting.cs index 3e6f8a23f8..3cb816d1f1 100644 --- a/projects/Unit/TestFrameFormatting.cs +++ b/projects/Unit/TestFrameFormatting.cs @@ -34,11 +34,16 @@ using RabbitMQ.Client.Framing.Impl; using Xunit; +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { public class TestFrameFormatting : WireFormattingFixture { + public TestFrameFormatting(ITestOutputHelper output) : base(output) + { + } + [Fact] public void HeartbeatFrame() { diff --git a/projects/Unit/TestHeartbeats.cs b/projects/Unit/TestHeartbeats.cs index dad47f3742..221628612e 100644 --- a/projects/Unit/TestHeartbeats.cs +++ b/projects/Unit/TestHeartbeats.cs @@ -149,10 +149,10 @@ private void CheckInitiator(ShutdownEventArgs evt) { if (InitiatedByPeerOrLibrary(evt)) { - Console.WriteLine(((Exception)evt.Cause).StackTrace); + _output.WriteLine(((Exception)evt.Cause).StackTrace); string s = string.Format("Shutdown: {0}, initiated by: {1}", evt, evt.Initiator); - Console.WriteLine(s); + _output.WriteLine(s); Assert.True(false, s); } } @@ -169,7 +169,7 @@ private bool LongRunningTestsEnabled() private void SleepFor(int t) { - Console.WriteLine("Testing heartbeats, sleeping for {0} seconds", t); + _output.WriteLine("Testing heartbeats, sleeping for {0} seconds", t); Thread.Sleep(t * 1000); } } diff --git a/projects/Unit/TestInvalidAck.cs b/projects/Unit/TestInvalidAck.cs index f006bfed69..23ae60ad16 100644 --- a/projects/Unit/TestInvalidAck.cs +++ b/projects/Unit/TestInvalidAck.cs @@ -46,18 +46,18 @@ public TestInvalidAck(ITestOutputHelper output) : base(output) [Fact] public void TestAckWithUnknownConsumerTagAndMultipleFalse() { - object o = new object(); + ManualResetEventSlim manualResetEventSlim = new ManualResetEventSlim(); bool shutdownFired = false; ShutdownEventArgs shutdownArgs = null; _model.ModelShutdown += (s, args) => { shutdownFired = true; shutdownArgs = args; - Monitor.PulseAll(o); + manualResetEventSlim.Set(); }; _model.BasicAck(123456, false); - WaitOn(o); + Assert.True(manualResetEventSlim.Wait(TimingFixture.TestTimeout)); Assert.True(shutdownFired); AssertPreconditionFailed(shutdownArgs); } diff --git a/projects/Unit/TestMainLoop.cs b/projects/Unit/TestMainLoop.cs index 23f9edee9b..e7234e8cd0 100644 --- a/projects/Unit/TestMainLoop.cs +++ b/projects/Unit/TestMainLoop.cs @@ -68,7 +68,7 @@ public void TestCloseWithFaultyConsumer() ConnectionFactory connFactory = new ConnectionFactory(); IConnection c = connFactory.CreateConnection(); IModel m = _conn.CreateModel(); - object o = new object(); + ManualResetEventSlim manualResetEventSlim = new ManualResetEventSlim(false); string q = GenerateQueueName(); m.QueueDeclare(q, false, false, false, null); @@ -77,12 +77,11 @@ public void TestCloseWithFaultyConsumer() { ea = evt; c.Close(); - Monitor.PulseAll(o); + manualResetEventSlim.Set(); }; m.BasicConsume(q, true, new FaultyConsumer(_model)); m.BasicPublish("", q, _encoding.GetBytes("message")); - WaitOn(o); - + Assert.True(manualResetEventSlim.Wait(TimingFixture.TestTimeout)); Assert.NotNull(ea); Assert.False(c.IsOpen); Assert.Equal(200, c.CloseReason.ReplyCode); diff --git a/projects/Unit/TestMethodArgumentCodec.cs b/projects/Unit/TestMethodArgumentCodec.cs index c4d8ccfdb3..4d5e424b3b 100644 --- a/projects/Unit/TestMethodArgumentCodec.cs +++ b/projects/Unit/TestMethodArgumentCodec.cs @@ -36,12 +36,20 @@ using RabbitMQ.Client.Impl; using Xunit; +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { public class TestMethodArgumentCodec { + protected readonly ITestOutputHelper _output; + + public TestMethodArgumentCodec(ITestOutputHelper output) + { + _output = output; + } + private void Check(byte[] actual, byte[] expected) { try @@ -50,12 +58,12 @@ private void Check(byte[] actual, byte[] expected) } catch { - Console.WriteLine(); - Console.WriteLine("EXPECTED =================================================="); - DebugUtil.Dump(expected, Console.Out); - Console.WriteLine("ACTUAL ===================================================="); - DebugUtil.Dump(actual, Console.Out); - Console.WriteLine("==========================================================="); + _output.WriteLine(""); + _output.WriteLine("EXPECTED =================================================="); + DebugUtil.Dump(expected, _output); + _output.WriteLine("ACTUAL ===================================================="); + DebugUtil.Dump(actual, _output); + _output.WriteLine("==========================================================="); throw; } } diff --git a/projects/Unit/TestNetworkByteOrderSerialization.cs b/projects/Unit/TestNetworkByteOrderSerialization.cs index 6b29850613..e9b1d4841a 100644 --- a/projects/Unit/TestNetworkByteOrderSerialization.cs +++ b/projects/Unit/TestNetworkByteOrderSerialization.cs @@ -34,11 +34,19 @@ using RabbitMQ.Util; using Xunit; +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { public class TestNetworkByteOrderSerialization { + protected readonly ITestOutputHelper _output; + + public TestNetworkByteOrderSerialization(ITestOutputHelper output) + { + _output = output; + } + private void Check(byte[] actual, byte[] expected) { try @@ -47,12 +55,12 @@ private void Check(byte[] actual, byte[] expected) } catch { - Console.WriteLine(); - Console.WriteLine("EXPECTED =================================================="); - DebugUtil.Dump(expected, Console.Out); - Console.WriteLine("ACTUAL ===================================================="); - DebugUtil.Dump(actual, Console.Out); - Console.WriteLine("==========================================================="); + _output.WriteLine(""); + _output.WriteLine("EXPECTED =================================================="); + DebugUtil.Dump(expected, _output); + _output.WriteLine("ACTUAL ===================================================="); + DebugUtil.Dump(actual, _output); + _output.WriteLine("==========================================================="); throw; } } diff --git a/projects/Unit/TestPublishSharedModel.cs b/projects/Unit/TestPublishSharedModel.cs index 1fcf6ad017..9d49acdc7b 100644 --- a/projects/Unit/TestPublishSharedModel.cs +++ b/projects/Unit/TestPublishSharedModel.cs @@ -77,17 +77,14 @@ public async Task MultiThreadPublishOnSharedModel() model.QueueBind(QueueName, ExchangeName.Value, PublishKey.Value, null); // Act - var pubTask = Task.Run(() => NewFunction(model)); - var pubTask2 = Task.Run(() => NewFunction(model)); - - await Task.WhenAll(pubTask, pubTask2); + await Task.WhenAll(NewFunction(model), NewFunction(model)); } } // Assert Assert.Null(_raisedException); - void NewFunction(IModel model) + async Task NewFunction(IModel model) { try { @@ -98,7 +95,7 @@ void NewFunction(IModel model) model.BasicPublish(ExchangeName, PublishKey, _body, false); } - Thread.Sleep(1); + await Task.Delay(1); } } catch (Exception e) diff --git a/projects/Unit/TestUpdateSecret.cs b/projects/Unit/TestUpdateSecret.cs index 329e50446d..7d1027885c 100644 --- a/projects/Unit/TestUpdateSecret.cs +++ b/projects/Unit/TestUpdateSecret.cs @@ -49,7 +49,7 @@ public void TestUpdatingConnectionSecret() { if (!RabbitMQ380OrHigher()) { - Console.WriteLine("Not connected to RabbitMQ 3.8 or higher, skipping test"); + _output.WriteLine("Not connected to RabbitMQ 3.8 or higher, skipping test"); return; } diff --git a/projects/Unit/WireFormattingFixture.cs b/projects/Unit/WireFormattingFixture.cs index 0294d5981d..553d1c6ae5 100644 --- a/projects/Unit/WireFormattingFixture.cs +++ b/projects/Unit/WireFormattingFixture.cs @@ -32,11 +32,19 @@ using System; using Xunit; +using Xunit.Abstractions; namespace RabbitMQ.Client.Unit { public class WireFormattingFixture { + protected readonly ITestOutputHelper _output; + + public WireFormattingFixture(ITestOutputHelper output) + { + _output = output; + } + public void Check(byte[] actual, byte[] expected) { try @@ -45,12 +53,12 @@ public void Check(byte[] actual, byte[] expected) } catch { - Console.WriteLine(); - Console.WriteLine("EXPECTED =================================================="); - DebugUtil.Dump(expected, Console.Out); - Console.WriteLine("ACTUAL ===================================================="); - DebugUtil.Dump(actual, Console.Out); - Console.WriteLine("==========================================================="); + _output.WriteLine(""); + _output.WriteLine("EXPECTED =================================================="); + DebugUtil.Dump(expected, _output); + _output.WriteLine("ACTUAL ===================================================="); + DebugUtil.Dump(actual, _output); + _output.WriteLine("==========================================================="); throw; } } From 348d88abfa8b187952226a69a0f05a3ad6044808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stef=C3=A1n=20J=C3=B6kull=20Sigur=C3=B0arson?= Date: Mon, 25 Apr 2022 15:45:32 +0000 Subject: [PATCH 3/4] More test cleanups. --- projects/Unit/Fixtures.cs | 5 ++ projects/Unit/TestConnectionRecovery.cs | 79 +++++++++++-------------- 2 files changed, 40 insertions(+), 44 deletions(-) diff --git a/projects/Unit/Fixtures.cs b/projects/Unit/Fixtures.cs index f7b3ceaffa..470c6300be 100644 --- a/projects/Unit/Fixtures.cs +++ b/projects/Unit/Fixtures.cs @@ -375,6 +375,11 @@ internal void Wait(ManualResetEventSlim latch, TimeSpan? timeSpan = null) Assert.True(latch.Wait(timeSpan ?? TimeSpan.FromSeconds(30)), "waiting on a latch timed out"); } + internal void Wait(CountdownEvent countdownEvent, TimeSpan? timeSpan = null) + { + Assert.True(countdownEvent.Wait(timeSpan ?? TimeSpan.FromSeconds(30)), "waiting on a latch timed out"); + } + // // TLS // diff --git a/projects/Unit/TestConnectionRecovery.cs b/projects/Unit/TestConnectionRecovery.cs index 1240d5af6c..7052b606bd 100644 --- a/projects/Unit/TestConnectionRecovery.cs +++ b/projects/Unit/TestConnectionRecovery.cs @@ -96,13 +96,13 @@ public void TestBasicAckAfterChannelRecovery() _model.BasicQos(0, 1, false); string consumerTag = _model.BasicConsume(queueName, false, cons); - ManualResetEventSlim sl = PrepareForShutdown(_conn); - ManualResetEventSlim rl = PrepareForRecovery(_conn); + CountdownEvent countdownEvent = new CountdownEvent(2); + PrepareForShutdown(_conn, countdownEvent); + PrepareForRecovery(_conn, countdownEvent); PublishMessagesWhileClosingConn(queueName); - Wait(sl); - Wait(rl); + Wait(countdownEvent); Wait(allMessagesSeenLatch); } @@ -118,13 +118,13 @@ public void TestBasicNackAfterChannelRecovery() _model.BasicQos(0, 1, false); string consumerTag = _model.BasicConsume(queueName, false, cons); - ManualResetEventSlim sl = PrepareForShutdown(_conn); - ManualResetEventSlim rl = PrepareForRecovery(_conn); + CountdownEvent countdownEvent = new CountdownEvent(2); + PrepareForShutdown(_conn, countdownEvent); + PrepareForRecovery(_conn, countdownEvent); PublishMessagesWhileClosingConn(queueName); - Wait(sl); - Wait(rl); + Wait(countdownEvent); Wait(allMessagesSeenLatch); } @@ -140,13 +140,13 @@ public void TestBasicRejectAfterChannelRecovery() _model.BasicQos(0, 1, false); string consumerTag = _model.BasicConsume(queueName, false, cons); - ManualResetEventSlim sl = PrepareForShutdown(_conn); - ManualResetEventSlim rl = PrepareForRecovery(_conn); + CountdownEvent countdownEvent = new CountdownEvent(2); + PrepareForShutdown(_conn, countdownEvent); + PrepareForRecovery(_conn, countdownEvent); PublishMessagesWhileClosingConn(queueName); - Wait(sl); - Wait(rl); + Wait(allMessagesSeenLatch); } @@ -777,8 +777,9 @@ public async Task TestShutdownEventHandlersRecoveryOnConnectionAfterDelayedServe { int counter = 0; _conn.ConnectionShutdown += (c, args) => Interlocked.Increment(ref counter); - ManualResetEventSlim shutdownLatch = PrepareForShutdown(_conn); - ManualResetEventSlim recoveryLatch = PrepareForRecovery((AutorecoveringConnection)_conn); + CountdownEvent countdownEvent = new CountdownEvent(2); + PrepareForShutdown(_conn, countdownEvent); + PrepareForRecovery((AutorecoveringConnection)_conn, countdownEvent); Assert.True(_conn.IsOpen); @@ -793,8 +794,7 @@ public async Task TestShutdownEventHandlersRecoveryOnConnectionAfterDelayedServe StartRabbitMQ(); } - Wait(shutdownLatch, TimeSpan.FromSeconds(30)); - Wait(recoveryLatch, TimeSpan.FromSeconds(30)); + Wait(countdownEvent, TimeSpan.FromSeconds(60)); Assert.True(_conn.IsOpen); Assert.True(counter >= 1); } @@ -1065,47 +1065,41 @@ internal void CloseAndWaitForRecovery() internal void CloseAndWaitForRecovery(AutorecoveringConnection conn) { Stopwatch timer = Stopwatch.StartNew(); - ManualResetEventSlim sl = PrepareForShutdown(conn); - ManualResetEventSlim rl = PrepareForRecovery(conn); + CountdownEvent countdownEvent = new CountdownEvent(2); + PrepareForShutdown(conn, countdownEvent); + PrepareForRecovery(conn, countdownEvent); CloseConnection(conn); - Wait(sl); - Wait(rl); + Wait(countdownEvent); _output.WriteLine($"Shutdown and recovered RabbitMQ in {timer.ElapsedMilliseconds}ms"); } - internal ManualResetEventSlim PrepareForRecovery(IConnection conn) + internal static void PrepareForRecovery(IConnection conn, CountdownEvent countdownEvent) { - var latch = new ManualResetEventSlim(false); - AutorecoveringConnection aconn = conn as AutorecoveringConnection; - aconn.RecoverySucceeded += (source, ea) => latch.Set(); - - return latch; + aconn.RecoverySucceeded += (source, ea) => countdownEvent.Signal(); } - internal static ManualResetEventSlim PrepareForShutdown(IConnection conn) + internal static void PrepareForShutdown(IConnection conn, CountdownEvent countdownEvent) { - var latch = new ManualResetEventSlim(false); - AutorecoveringConnection aconn = conn as AutorecoveringConnection; - aconn.ConnectionShutdown += (c, args) => latch.Set(); - - return latch; + aconn.ConnectionShutdown += (c, args) => countdownEvent.Signal(); } internal async Task RestartServerAndWaitForRecoveryAsync() { + CountdownEvent countdownEvent = new CountdownEvent(2); AutorecoveringConnection conn = (AutorecoveringConnection)_conn; - ManualResetEventSlim sl = PrepareForShutdown(conn); - ManualResetEventSlim rl = PrepareForRecovery(conn); + PrepareForShutdown(conn, countdownEvent); + PrepareForRecovery(conn, countdownEvent); await RestartRabbitMQAsync(); - Wait(sl); - Wait(rl); + Wait(countdownEvent); } internal void WaitForShutdown(IConnection conn) { - Wait(PrepareForShutdown(conn)); + CountdownEvent countdownEvent = new CountdownEvent(1); + PrepareForShutdown(conn, countdownEvent); + Wait(countdownEvent); } internal void PublishMessagesWhileClosingConn(string queueName) @@ -1165,11 +1159,11 @@ public override void PostHandleDelivery(ulong deliveryTag) } } - public class TestBasicConsumer : DefaultBasicConsumer + public abstract class TestBasicConsumer : DefaultBasicConsumer { private readonly ManualResetEventSlim _allMessagesSeenLatch; private readonly ushort _totalMessageCount; - private ushort _counter = 0; + private int _counter = 0; public TestBasicConsumer(IModel model, ushort totalMessageCount, ManualResetEventSlim allMessagesSeenLatch) : base(model) @@ -1192,17 +1186,14 @@ public override void HandleBasicDeliver(string consumerTag, } finally { - ++_counter; - if (_counter >= _totalMessageCount) + if (Interlocked.Increment(ref _counter) == _totalMessageCount) { _allMessagesSeenLatch.Set(); } } } - public virtual void PostHandleDelivery(ulong deliveryTag) - { - } + public abstract void PostHandleDelivery(ulong deliveryTag); } } } From 96f1acafcdc7f6f36231505eb59daa285835e43c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stef=C3=A1n=20J=C3=B6kull=20Sigur=C3=B0arson?= Date: Mon, 25 Apr 2022 15:55:15 +0000 Subject: [PATCH 4/4] Adding logging to troubleshoot tests. --- projects/Unit/TestConnectionRecovery.cs | 46 +++++++++++++++---------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/projects/Unit/TestConnectionRecovery.cs b/projects/Unit/TestConnectionRecovery.cs index 7052b606bd..030517930b 100644 --- a/projects/Unit/TestConnectionRecovery.cs +++ b/projects/Unit/TestConnectionRecovery.cs @@ -97,8 +97,8 @@ public void TestBasicAckAfterChannelRecovery() string consumerTag = _model.BasicConsume(queueName, false, cons); CountdownEvent countdownEvent = new CountdownEvent(2); - PrepareForShutdown(_conn, countdownEvent); - PrepareForRecovery(_conn, countdownEvent); + PrepareForShutdown(_conn, countdownEvent, _output); + PrepareForRecovery(_conn, countdownEvent, _output); PublishMessagesWhileClosingConn(queueName); @@ -119,8 +119,8 @@ public void TestBasicNackAfterChannelRecovery() string consumerTag = _model.BasicConsume(queueName, false, cons); CountdownEvent countdownEvent = new CountdownEvent(2); - PrepareForShutdown(_conn, countdownEvent); - PrepareForRecovery(_conn, countdownEvent); + PrepareForShutdown(_conn, countdownEvent, _output); + PrepareForRecovery(_conn, countdownEvent, _output); PublishMessagesWhileClosingConn(queueName); @@ -141,8 +141,8 @@ public void TestBasicRejectAfterChannelRecovery() string consumerTag = _model.BasicConsume(queueName, false, cons); CountdownEvent countdownEvent = new CountdownEvent(2); - PrepareForShutdown(_conn, countdownEvent); - PrepareForRecovery(_conn, countdownEvent); + PrepareForShutdown(_conn, countdownEvent, _output); + PrepareForRecovery(_conn, countdownEvent, _output); PublishMessagesWhileClosingConn(queueName); @@ -446,7 +446,7 @@ public void TestCreateModelOnClosedAutorecoveringConnectionDoesNotHang() try { c.Close(); - WaitForShutdown(c); + WaitForShutdown(c, _output); Assert.False(c.IsOpen); c.CreateModel(); Assert.True(false, "Expected an exception"); @@ -778,8 +778,8 @@ public async Task TestShutdownEventHandlersRecoveryOnConnectionAfterDelayedServe int counter = 0; _conn.ConnectionShutdown += (c, args) => Interlocked.Increment(ref counter); CountdownEvent countdownEvent = new CountdownEvent(2); - PrepareForShutdown(_conn, countdownEvent); - PrepareForRecovery((AutorecoveringConnection)_conn, countdownEvent); + PrepareForShutdown(_conn, countdownEvent, _output); + PrepareForRecovery((AutorecoveringConnection)_conn, countdownEvent, _output); Assert.True(_conn.IsOpen); @@ -1066,39 +1066,47 @@ internal void CloseAndWaitForRecovery(AutorecoveringConnection conn) { Stopwatch timer = Stopwatch.StartNew(); CountdownEvent countdownEvent = new CountdownEvent(2); - PrepareForShutdown(conn, countdownEvent); - PrepareForRecovery(conn, countdownEvent); + PrepareForShutdown(conn, countdownEvent, _output); + PrepareForRecovery(conn, countdownEvent, _output); CloseConnection(conn); Wait(countdownEvent); _output.WriteLine($"Shutdown and recovered RabbitMQ in {timer.ElapsedMilliseconds}ms"); } - internal static void PrepareForRecovery(IConnection conn, CountdownEvent countdownEvent) + internal static void PrepareForRecovery(IConnection conn, CountdownEvent countdownEvent, ITestOutputHelper testOutputHelper) { AutorecoveringConnection aconn = conn as AutorecoveringConnection; - aconn.RecoverySucceeded += (source, ea) => countdownEvent.Signal(); + aconn.RecoverySucceeded += (source, ea) => + { + testOutputHelper.WriteLine("Received recovery succeeded event."); + countdownEvent.Signal(); + }; } - internal static void PrepareForShutdown(IConnection conn, CountdownEvent countdownEvent) + internal static void PrepareForShutdown(IConnection conn, CountdownEvent countdownEvent, ITestOutputHelper testOutputHelper) { AutorecoveringConnection aconn = conn as AutorecoveringConnection; - aconn.ConnectionShutdown += (c, args) => countdownEvent.Signal(); + aconn.ConnectionShutdown += (c, args) => + { + testOutputHelper.WriteLine("Received connection shutdown event."); + countdownEvent.Signal(); + }; } internal async Task RestartServerAndWaitForRecoveryAsync() { CountdownEvent countdownEvent = new CountdownEvent(2); AutorecoveringConnection conn = (AutorecoveringConnection)_conn; - PrepareForShutdown(conn, countdownEvent); - PrepareForRecovery(conn, countdownEvent); + PrepareForShutdown(conn, countdownEvent, _output); + PrepareForRecovery(conn, countdownEvent, _output); await RestartRabbitMQAsync(); Wait(countdownEvent); } - internal void WaitForShutdown(IConnection conn) + internal void WaitForShutdown(IConnection conn, ITestOutputHelper testOutputHelper) { CountdownEvent countdownEvent = new CountdownEvent(1); - PrepareForShutdown(conn, countdownEvent); + PrepareForShutdown(conn, countdownEvent, testOutputHelper); Wait(countdownEvent); }