diff --git a/projects/Apigen/apigen/Apigen.cs b/projects/Apigen/apigen/Apigen.cs index f9460243c9..63b1bdbe46 100644 --- a/projects/Apigen/apigen/Apigen.cs +++ b/projects/Apigen/apigen/Apigen.cs @@ -855,9 +855,28 @@ public void EmitClassMethodImplementations(AmqpClass c) EmitLine(""); EmitLine(" public override void WriteArgumentsTo(ref Client.Impl.MethodArgumentWriter writer)"); EmitLine(" {"); + var lastWasBitClass = false; foreach (AmqpField f in m.m_Fields) { - EmitLine($" writer.Write{MangleClass(ResolveDomain(f.Domain))}(_{MangleMethod(f.Name)});"); + string mangleClass = MangleClass(ResolveDomain(f.Domain)); + if (mangleClass != "Bit") + { + if (lastWasBitClass) + { + EmitLine($" writer.EndBits();"); + lastWasBitClass = false; + } + } + else + { + lastWasBitClass = true; + } + + EmitLine($" writer.Write{mangleClass}(_{MangleMethod(f.Name)});"); + } + if (lastWasBitClass) + { + EmitLine($" writer.EndBits();"); } EmitLine(" }"); EmitLine(""); @@ -933,14 +952,14 @@ public void EmitClassMethodImplementations(AmqpClass c) public void EmitMethodArgumentReader() { - EmitLine(" internal override Client.Impl.MethodBase DecodeMethodFrom(ReadOnlyMemory memory)"); + EmitLine(" internal override Client.Impl.MethodBase DecodeMethodFrom(ReadOnlySpan span)"); EmitLine(" {"); - EmitLine(" ushort classId = Util.NetworkOrderDeserializer.ReadUInt16(memory.Span);"); - EmitLine(" ushort methodId = Util.NetworkOrderDeserializer.ReadUInt16(memory.Slice(2).Span);"); + EmitLine(" ushort classId = Util.NetworkOrderDeserializer.ReadUInt16(span);"); + EmitLine(" ushort methodId = Util.NetworkOrderDeserializer.ReadUInt16(span.Slice(2));"); EmitLine(" Client.Impl.MethodBase result = DecodeMethodFrom(classId, methodId);"); EmitLine(" if(result != null)"); EmitLine(" {"); - EmitLine(" Client.Impl.MethodArgumentReader reader = new Client.Impl.MethodArgumentReader(memory.Slice(4));"); + EmitLine(" Client.Impl.MethodArgumentReader reader = new Client.Impl.MethodArgumentReader(span.Slice(4));"); EmitLine(" result.ReadArgumentsFrom(ref reader);"); EmitLine(" return result;"); EmitLine(" }"); diff --git a/projects/RabbitMQ.Client/RabbitMQ.Client.csproj b/projects/RabbitMQ.Client/RabbitMQ.Client.csproj index 95ea87d43b..ebe5081ffb 100755 --- a/projects/RabbitMQ.Client/RabbitMQ.Client.csproj +++ b/projects/RabbitMQ.Client/RabbitMQ.Client.csproj @@ -26,6 +26,7 @@ minimal true ..\..\packages + true diff --git a/projects/RabbitMQ.Client/client/impl/Command.cs b/projects/RabbitMQ.Client/client/impl/Command.cs index 06b908813d..c270bc44f0 100644 --- a/projects/RabbitMQ.Client/client/impl/Command.cs +++ b/projects/RabbitMQ.Client/client/impl/Command.cs @@ -57,11 +57,6 @@ class Command : IDisposable private const int EmptyFrameSize = 8; private readonly bool _returnBufferOnDispose; - static Command() - { - CheckEmptyFrameSize(); - } - internal Command(MethodBase method) : this(method, null, null, false) { } @@ -80,23 +75,6 @@ public Command(MethodBase method, ContentHeaderBase header, ReadOnlyMemory internal MethodBase Method { get; private set; } - public static void CheckEmptyFrameSize() - { - var f = new EmptyOutboundFrame(); - byte[] b = new byte[f.GetMinimumBufferSize()]; - f.WriteTo(b); - long actualLength = f.ByteCount; - - if (EmptyFrameSize != actualLength) - { - string message = - string.Format("EmptyFrameSize is incorrect - defined as {0} where the computed value is in fact {1}.", - EmptyFrameSize, - actualLength); - throw new ProtocolViolationException(message); - } - } - internal void Transmit(int channelNumber, Connection connection) { connection.WriteFrame(new MethodOutboundFrame(channelNumber, Method)); diff --git a/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs b/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs index bfca721d89..cc7c1a11a8 100644 --- a/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs +++ b/projects/RabbitMQ.Client/client/impl/CommandAssembler.cs @@ -81,7 +81,7 @@ public Command HandleFrame(in InboundFrame f) { throw new UnexpectedFrameException(f.Type); } - m_method = m_protocol.DecodeMethodFrom(f.Payload); + m_method = m_protocol.DecodeMethodFrom(f.Payload.Span); m_state = m_method.HasContent ? AssemblyState.ExpectingContentHeader : AssemblyState.Complete; return CompletedCommand(); case AssemblyState.ExpectingContentHeader: @@ -89,8 +89,10 @@ public Command HandleFrame(in InboundFrame f) { throw new UnexpectedFrameException(f.Type); } - m_header = m_protocol.DecodeContentHeaderFrom(NetworkOrderDeserializer.ReadUInt16(f.Payload.Span)); - ulong totalBodyBytes = m_header.ReadFrom(f.Payload.Slice(2)); + + ReadOnlySpan span = f.Payload.Span; + m_header = m_protocol.DecodeContentHeaderFrom(NetworkOrderDeserializer.ReadUInt16(span)); + ulong totalBodyBytes = m_header.ReadFrom(span.Slice(2)); if (totalBodyBytes > MaxArrayOfBytesSize) { throw new UnexpectedFrameException(f.Type); diff --git a/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs b/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs index 946e48b61b..937f2f51e0 100644 --- a/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ContentHeaderBase.cs @@ -67,11 +67,11 @@ public virtual object Clone() /// /// Fill this instance from the given byte buffer stream. /// - internal ulong ReadFrom(ReadOnlyMemory memory) + internal ulong ReadFrom(ReadOnlySpan span) { // Skipping the first two bytes since they arent used (weight - not currently used) - ulong bodySize = NetworkOrderDeserializer.ReadUInt64(memory.Slice(2).Span); - ContentHeaderPropertyReader reader = new ContentHeaderPropertyReader(memory.Slice(10)); + ulong bodySize = NetworkOrderDeserializer.ReadUInt64(span.Slice(2)); + ContentHeaderPropertyReader reader = new ContentHeaderPropertyReader(span.Slice(10)); ReadPropertiesFrom(ref reader); return bodySize; } @@ -81,12 +81,12 @@ internal ulong ReadFrom(ReadOnlyMemory memory) private const ushort ZERO = 0; - internal int WriteTo(Memory memory, ulong bodySize) + internal int WriteTo(Span span, ulong bodySize) { - NetworkOrderSerializer.WriteUInt16(memory.Span, ZERO); // Weight - not used - NetworkOrderSerializer.WriteUInt64(memory.Slice(2).Span, bodySize); + NetworkOrderSerializer.WriteUInt16(span, ZERO); // Weight - not used + NetworkOrderSerializer.WriteUInt64(span.Slice(2), bodySize); - ContentHeaderPropertyWriter writer = new ContentHeaderPropertyWriter(memory.Slice(10)); + ContentHeaderPropertyWriter writer = new ContentHeaderPropertyWriter(span.Slice(10)); WritePropertiesTo(ref writer); return 10 + writer.Offset; } diff --git a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs index d8a7ea1671..3adb5e16a4 100644 --- a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs +++ b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyReader.cs @@ -45,26 +45,28 @@ namespace RabbitMQ.Client.Impl { - internal struct ContentHeaderPropertyReader + internal ref struct ContentHeaderPropertyReader { - private ushort m_bitCount; - private ushort m_flagWord; - private int _memoryOffset; - private readonly ReadOnlyMemory _memory; + private const int StartBitMask = 0b1000_0000_0000_0000; + private const int EndBitMask = 0b0000_0000_0000_0001; - public ContentHeaderPropertyReader(ReadOnlyMemory memory) - { - _memory = memory; - _memoryOffset = 0; - m_flagWord = 1; // just the continuation bit - m_bitCount = 15; // the correct position to force a m_flagWord read - } + private readonly ReadOnlySpan _span; + private int _offset; + private int _bitMask; + private int _bits; - public bool ContinuationBitSet + private ReadOnlySpan Span => _span.Slice(_offset); + + public ContentHeaderPropertyReader(ReadOnlySpan span) { - get { return (m_flagWord & 1) != 0; } + _span = span; + _offset = 0; + _bitMask = EndBitMask; // force a flag read + _bits = 1; // just the continuation bit } + private bool ContinuationBitSet => (_bits & EndBitMask) != 0; + public void FinishPresence() { if (ContinuationBitSet) @@ -78,82 +80,81 @@ public bool ReadBit() return ReadPresence(); } - public void ReadFlagWord() + private void ReadBits() { if (!ContinuationBitSet) { throw new MalformedFrameException("Attempted to read flag word when none advertised"); } - m_flagWord = NetworkOrderDeserializer.ReadUInt16(_memory.Slice(_memoryOffset).Span); - _memoryOffset += 2; - m_bitCount = 0; + _bits = NetworkOrderDeserializer.ReadUInt16(Span); + _offset += 2; + _bitMask = StartBitMask; } public uint ReadLong() { - uint result = NetworkOrderDeserializer.ReadUInt32(_memory.Slice(_memoryOffset).Span); - _memoryOffset += 4; + uint result = NetworkOrderDeserializer.ReadUInt32(Span); + _offset += 4; return result; } public ulong ReadLonglong() { - ulong result = NetworkOrderDeserializer.ReadUInt64(_memory.Slice(_memoryOffset).Span); - _memoryOffset += 8; + ulong result = NetworkOrderDeserializer.ReadUInt64(Span); + _offset += 8; return result; } public byte[] ReadLongstr() { - byte[] result = WireFormatting.ReadLongstr(_memory.Slice(_memoryOffset)); - _memoryOffset += 4 + result.Length; + byte[] result = WireFormatting.ReadLongstr(Span); + _offset += 4 + result.Length; return result; } public byte ReadOctet() { - return _memory.Span[_memoryOffset++]; + return _span[_offset++]; } public bool ReadPresence() { - if (m_bitCount == 15) + if (_bitMask == EndBitMask) { - ReadFlagWord(); + ReadBits(); } - int bit = 15 - m_bitCount; - bool result = (m_flagWord & (1 << bit)) != 0; - m_bitCount++; + bool result = (_bits & _bitMask) != 0; + _bitMask >>= 1; return result; } public ushort ReadShort() { - ushort result = NetworkOrderDeserializer.ReadUInt16(_memory.Slice(_memoryOffset).Span); - _memoryOffset += 2; + ushort result = NetworkOrderDeserializer.ReadUInt16(Span); + _offset += 2; return result; } public string ReadShortstr() { - string result = WireFormatting.ReadShortstr(_memory.Slice(_memoryOffset), out int bytesRead); - _memoryOffset += bytesRead; + string result = WireFormatting.ReadShortstr(Span, out int bytesRead); + _offset += bytesRead; return result; } /// A type of . public Dictionary ReadTable() { - Dictionary result = WireFormatting.ReadTable(_memory.Slice(_memoryOffset), out int bytesRead); - _memoryOffset += bytesRead; + Dictionary result = WireFormatting.ReadTable(Span, out int bytesRead); + _offset += bytesRead; return result; } public AmqpTimestamp ReadTimestamp() { - AmqpTimestamp result = WireFormatting.ReadTimestamp(_memory.Slice(_memoryOffset)); - _memoryOffset += 8; + AmqpTimestamp result = WireFormatting.ReadTimestamp(Span); + _offset += 8; return result; } } diff --git a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs index 871d4f500d..3d9c6fb16b 100644 --- a/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs +++ b/projects/RabbitMQ.Client/client/impl/ContentHeaderPropertyWriter.cs @@ -45,24 +45,31 @@ namespace RabbitMQ.Client.Impl { - struct ContentHeaderPropertyWriter + internal ref struct ContentHeaderPropertyWriter { - private int _bitCount; - private ushort _flagWord; - public int Offset { get; private set; } - public Memory Memory { get; private set; } + private const ushort StartBitMask = 0b1000_0000_0000_0000; + private const ushort EndBitMask = 0b0000_0000_0000_0001; - public ContentHeaderPropertyWriter(Memory memory) + private readonly Span _span; + private int _offset; + private ushort _bitAccumulator; + private ushort _bitMask; + + public int Offset => _offset; + + private Span Span => _span.Slice(_offset); + + public ContentHeaderPropertyWriter(Span span) { - Memory = memory; - _flagWord = 0; - _bitCount = 0; - Offset = 0; + _span = span; + _offset = 0; + _bitAccumulator = 0; + _bitMask = StartBitMask; } public void FinishPresence() { - EmitFlagWord(false); + WriteBits(); } public void WriteBit(bool bit) @@ -72,65 +79,67 @@ public void WriteBit(bool bit) public void WriteLong(uint val) { - Offset += WireFormatting.WriteLong(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteLong(Span, val); } public void WriteLonglong(ulong val) { - Offset += WireFormatting.WriteLonglong(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteLonglong(Span, val); } public void WriteLongstr(byte[] val) { - Offset += WireFormatting.WriteLongstr(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteLongstr(Span, val); } public void WriteOctet(byte val) { - Memory.Slice(Offset++).Span[0] = val; + _span[_offset++] = val; } public void WritePresence(bool present) { - if (_bitCount == 15) + if (_bitMask == EndBitMask) { - EmitFlagWord(true); + // Mark continuation + _bitAccumulator |= _bitMask; + WriteBits(); } if (present) { - int bit = 15 - _bitCount; - _flagWord = (ushort)(_flagWord | (1 << bit)); + _bitAccumulator |= _bitMask; } - _bitCount++; + + _bitMask >>= 1; } public void WriteShort(ushort val) { - Offset += WireFormatting.WriteShort(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteShort(Span, val); } public void WriteShortstr(string val) { - Offset += WireFormatting.WriteShortstr(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteShortstr(Span, val); } public void WriteTable(IDictionary val) { - Offset += WireFormatting.WriteTable(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteTable(Span, val); } public void WriteTimestamp(AmqpTimestamp val) { - Offset += WireFormatting.WriteTimestamp(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteTimestamp(Span, val); } - private void EmitFlagWord(bool continuationBit) + private void WriteBits() { - NetworkOrderSerializer.WriteUInt16(Memory.Slice(Offset).Span, (ushort)(continuationBit ? (_flagWord | 1) : _flagWord)); - Offset += 2; - _flagWord = 0; - _bitCount = 0; + NetworkOrderSerializer.WriteUInt16(Span, _bitAccumulator); + _offset += 2; + _bitMask = StartBitMask; + _bitAccumulator = 0; } } } diff --git a/projects/RabbitMQ.Client/client/impl/Frame.cs b/projects/RabbitMQ.Client/client/impl/Frame.cs index f1c2682a12..aba05bc9fb 100644 --- a/projects/RabbitMQ.Client/client/impl/Frame.cs +++ b/projects/RabbitMQ.Client/client/impl/Frame.cs @@ -45,7 +45,6 @@ using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using RabbitMQ.Client.Exceptions; -using RabbitMQ.Client.Framing; using RabbitMQ.Util; namespace RabbitMQ.Client.Impl @@ -67,13 +66,13 @@ internal override int GetMinimumPayloadBufferSize() return 2 + _header.GetRequiredBufferSize(); } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { // write protocol class id (2 bytes) - NetworkOrderSerializer.WriteUInt16(memory.Span, _header.ProtocolClassId); + NetworkOrderSerializer.WriteUInt16(span, _header.ProtocolClassId); // write header (X bytes) - int bytesWritten = _header.WriteTo(memory.Slice(2), (ulong)_bodyLength); - return 2 + bytesWritten; + int bytesWritten = _header.WriteTo(span.Slice(2), (ulong)_bodyLength); + return bytesWritten + 2; } } @@ -91,9 +90,9 @@ internal override int GetMinimumPayloadBufferSize() return _body.Length; } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { - _body.CopyTo(memory); + _body.Span.CopyTo(span); return _body.Length; } } @@ -113,13 +112,12 @@ internal override int GetMinimumPayloadBufferSize() return 4 + _method.GetRequiredBufferSize(); } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { - NetworkOrderSerializer.WriteUInt16(memory.Span, _method.ProtocolClassId); - NetworkOrderSerializer.WriteUInt16(memory.Slice(2).Span, _method.ProtocolMethodId); - var argWriter = new MethodArgumentWriter(memory.Slice(4)); + NetworkOrderSerializer.WriteUInt16(span, _method.ProtocolClassId); + NetworkOrderSerializer.WriteUInt16(span.Slice(2), _method.ProtocolMethodId); + var argWriter = new MethodArgumentWriter(span.Slice(4)); _method.WriteArgumentsTo(ref argWriter); - argWriter.Flush(); return 4 + argWriter.Offset; } } @@ -135,35 +133,43 @@ internal override int GetMinimumPayloadBufferSize() return 0; } - internal override int WritePayload(Memory memory) + internal override int WritePayload(Span span) { return 0; } } - abstract class OutboundFrame : Frame + internal abstract class OutboundFrame { - public int ByteCount { get; private set; } = 0; - public OutboundFrame(FrameType type, int channel) : base(type, channel) + public int Channel { get; } + public FrameType Type { get; } + + protected OutboundFrame(FrameType type, int channel) { + Type = type; + Channel = channel; } - internal void WriteTo(Memory memory) + internal void WriteTo(Span span) { - memory.Span[0] = (byte)Type; - NetworkOrderSerializer.WriteUInt16(memory.Slice(1).Span, (ushort)Channel); - int bytesWritten = WritePayload(memory.Slice(7)); - NetworkOrderSerializer.WriteUInt32(memory.Slice(3).Span, (uint)bytesWritten); - memory.Span[bytesWritten + 7] = Constants.FrameEnd; - ByteCount = bytesWritten + 8; + span[0] = (byte)Type; + NetworkOrderSerializer.WriteUInt16(span.Slice(1), (ushort)Channel); + int bytesWritten = WritePayload(span.Slice(7)); + NetworkOrderSerializer.WriteUInt32(span.Slice(3), (uint)bytesWritten); + span[bytesWritten + 7] = Constants.FrameEnd; } - internal abstract int WritePayload(Memory memory); + internal abstract int WritePayload(Span span); internal abstract int GetMinimumPayloadBufferSize(); internal int GetMinimumBufferSize() { return 8 + GetMinimumPayloadBufferSize(); } + + public override string ToString() + { + return $"(type={Type}, channel={Channel})"; + } } internal readonly struct InboundFrame : IDisposable @@ -298,53 +304,18 @@ public void Dispose() ArrayPool.Shared.Return(segment.Array); } } - - public override string ToString() - { - return $"(type={Type}, channel={Channel}, {Payload.Length} bytes of payload)"; - } - } - - class Frame - { - public Frame(FrameType type, int channel) - { - Type = type; - Channel = channel; - Payload = null; - } - - public Frame(FrameType type, int channel, ReadOnlyMemory payload) - { - Type = type; - Channel = channel; - Payload = payload; - } - - public int Channel { get; private set; } - - public ReadOnlyMemory Payload { get; private set; } - - public FrameType Type { get; private set; } public override string ToString() { - return string.Format("(type={0}, channel={1}, {2} bytes of payload)", - Type, - Channel, - Payload.Length.ToString()); + return $"(type={Type}, channel={Channel}, {Payload.Length} bytes of payload)"; } - - } - enum FrameType : int + internal enum FrameType : int { - FrameMethod = 1, - FrameHeader = 2, - FrameBody = 3, - FrameHeartbeat = 8, - FrameEnd = 206, - FrameMinSize = 4096 + FrameMethod = Constants.FrameMethod, + FrameHeader = Constants.FrameHeader, + FrameBody = Constants.FrameBody, + FrameHeartbeat = Constants.FrameHeartbeat } } diff --git a/projects/RabbitMQ.Client/client/impl/MainSession.cs b/projects/RabbitMQ.Client/client/impl/MainSession.cs index e440a04d3f..6b863643c5 100644 --- a/projects/RabbitMQ.Client/client/impl/MainSession.cs +++ b/projects/RabbitMQ.Client/client/impl/MainSession.cs @@ -84,7 +84,7 @@ public override void HandleFrame(in InboundFrame frame) if (!_closeServerInitiated && frame.IsMethod()) { - MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload); + MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload.Span); if ((method.ProtocolClassId == _closeClassId) && (method.ProtocolMethodId == _closeMethodId)) { diff --git a/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs b/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs index cc529e2156..6a737cb36c 100644 --- a/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs +++ b/projects/RabbitMQ.Client/client/impl/MethodArgumentReader.cs @@ -45,32 +45,34 @@ namespace RabbitMQ.Client.Impl { - internal struct MethodArgumentReader + internal ref struct MethodArgumentReader { - private int? _bit; + private readonly ReadOnlySpan _span; + private int _offset; + private int _bitMask; private int _bits; - public MethodArgumentReader(ReadOnlyMemory memory) + private ReadOnlySpan Span => _span.Slice(_offset); + + public MethodArgumentReader(ReadOnlySpan span) { - _memory = memory; - _memoryOffset = 0; + _span = span; + _offset = 0; + _bitMask = 0; _bits = 0; - _bit = null; } - private readonly ReadOnlyMemory _memory; - private int _memoryOffset; - public bool ReadBit() { - if (!_bit.HasValue) + int bit = _bitMask; + if (bit == 0) { - _bits = _memory.Span[_memoryOffset++]; - _bit = 0x01; + _bits = _span[_offset++]; + bit = 1; } - bool result = (_bits & _bit.Value) != 0; - _bit <<= 1; + bool result = (_bits & bit) != 0; + _bitMask = bit << 1; return result; } @@ -81,74 +83,56 @@ public byte[] ReadContent() public uint ReadLong() { - ClearBits(); - uint result = NetworkOrderDeserializer.ReadUInt32(_memory.Slice(_memoryOffset).Span); - _memoryOffset += 4; + uint result = NetworkOrderDeserializer.ReadUInt32(Span); + _offset += 4; return result; } public ulong ReadLonglong() { - ClearBits(); - ulong result = NetworkOrderDeserializer.ReadUInt64(_memory.Slice(_memoryOffset).Span); - _memoryOffset += 8; + ulong result = NetworkOrderDeserializer.ReadUInt64(Span); + _offset += 8; return result; } public byte[] ReadLongstr() { - ClearBits(); - byte[] result = WireFormatting.ReadLongstr(_memory.Slice(_memoryOffset)); - _memoryOffset += 4 + result.Length; + byte[] result = WireFormatting.ReadLongstr(Span); + _offset += 4 + result.Length; return result; } public byte ReadOctet() { - ClearBits(); - return _memory.Span[_memoryOffset++]; + return _span[_offset++]; } public ushort ReadShort() { - ClearBits(); - ushort result = NetworkOrderDeserializer.ReadUInt16(_memory.Slice(_memoryOffset).Span); - _memoryOffset += 2; + ushort result = NetworkOrderDeserializer.ReadUInt16(Span); + _offset += 2; return result; } public string ReadShortstr() { - ClearBits(); - string result = WireFormatting.ReadShortstr(_memory.Slice(_memoryOffset), out int bytesRead); - _memoryOffset += bytesRead; + string result = WireFormatting.ReadShortstr(Span, out int bytesRead); + _offset += bytesRead; return result; } public Dictionary ReadTable() { - ClearBits(); - Dictionary result = WireFormatting.ReadTable(_memory.Slice(_memoryOffset), out int bytesRead); - _memoryOffset += bytesRead; + Dictionary result = WireFormatting.ReadTable(Span, out int bytesRead); + _offset += bytesRead; return result; } public AmqpTimestamp ReadTimestamp() { - ClearBits(); - AmqpTimestamp result = WireFormatting.ReadTimestamp(_memory.Slice(_memoryOffset)); - _memoryOffset += 8; + AmqpTimestamp result = WireFormatting.ReadTimestamp(Span); + _offset += 8; return result; } - - private void ClearBits() - { - _bits = 0; - _bit = null; - } - - // TODO: Consider using NotImplementedException (?) - // This is a completely bizarre consequence of the way the - // Message.Transfer method is marked up in the XML spec. } } diff --git a/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs b/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs index b5ed324109..35c9a0ecdc 100644 --- a/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs +++ b/projects/RabbitMQ.Client/client/impl/MethodArgumentWriter.cs @@ -44,44 +44,42 @@ namespace RabbitMQ.Client.Impl { - struct MethodArgumentWriter + internal ref struct MethodArgumentWriter { - private byte _bitAccumulator; + private readonly Span _span; + private int _offset; + private int _bitAccumulator; private int _bitMask; - private bool _needBitFlush; - public int Offset { get; private set; } - public Memory Memory { get; private set; } - public MethodArgumentWriter(Memory memory) + public int Offset => _offset; + + private Span Span => _span.Slice(_offset); + + public MethodArgumentWriter(Span span) { - Memory = memory; - _needBitFlush = false; + _span = span; + _offset = 0; _bitAccumulator = 0; _bitMask = 1; - Offset = 0; - } - - public void Flush() - { - BitFlush(); } public void WriteBit(bool val) { - if (_bitMask > 0x80) - { - BitFlush(); - } if (val) { - // The cast below is safe, because the combination of - // the test against 0x80 above, and the action of - // BitFlush(), causes m_bitMask never to exceed 0x80 - // at the point the following statement executes. - _bitAccumulator = (byte)(_bitAccumulator | (byte)_bitMask); + _bitAccumulator |= _bitMask; } _bitMask <<= 1; - _needBitFlush = true; + } + + public void EndBits() + { + if (_bitMask > 1) + { + _span[_offset++] = (byte)_bitAccumulator; + _bitAccumulator = 0; + _bitMask = 1; + } } public void WriteContent(byte[] val) @@ -91,76 +89,47 @@ public void WriteContent(byte[] val) public void WriteLong(uint val) { - BitFlush(); - Offset += WireFormatting.WriteLong(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteLong(Span, val); } public void WriteLonglong(ulong val) { - BitFlush(); - Offset += WireFormatting.WriteLonglong(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteLonglong(Span, val); } public void WriteLongstr(byte[] val) { - BitFlush(); - Offset += WireFormatting.WriteLongstr(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteLongstr(Span, val); } public void WriteOctet(byte val) { - BitFlush(); - Memory.Slice(Offset++).Span[0] = val; + _span[_offset++] = val; } public void WriteShort(ushort val) { - BitFlush(); - Offset += WireFormatting.WriteShort(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteShort(Span, val); } public void WriteShortstr(string val) { - BitFlush(); - Offset += WireFormatting.WriteShortstr(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteShortstr(Span, val); } public void WriteTable(IDictionary val) { - BitFlush(); - Offset += WireFormatting.WriteTable(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteTable(Span, val); } public void WriteTable(IDictionary val) { - BitFlush(); - Offset += WireFormatting.WriteTable(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteTable(Span, val); } public void WriteTimestamp(AmqpTimestamp val) { - BitFlush(); - Offset += WireFormatting.WriteTimestamp(Memory.Slice(Offset), val); + _offset += WireFormatting.WriteTimestamp(Span, val); } - - private void BitFlush() - { - if (_needBitFlush) - { - Memory.Slice(Offset++).Span[0] = _bitAccumulator; - ResetBitAccumulator(); - } - } - - private void ResetBitAccumulator() - { - _needBitFlush = false; - _bitAccumulator = 0; - _bitMask = 1; - } - - // TODO: Consider using NotImplementedException (?) - // This is a completely bizarre consequence of the way the - // Message.Transfer method is marked up in the XML spec. } } diff --git a/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs b/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs index f4be4336aa..6ae4ed75c5 100644 --- a/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ProtocolBase.cs @@ -106,7 +106,7 @@ public void CreateConnectionClose(ushort reasonCode, } internal abstract ContentHeaderBase DecodeContentHeaderFrom(ushort classId); - internal abstract MethodBase DecodeMethodFrom(ReadOnlyMemory reader); + internal abstract MethodBase DecodeMethodFrom(ReadOnlySpan reader); public override bool Equals(object obj) { diff --git a/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs b/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs index 633426b133..ca5a1209f3 100644 --- a/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs +++ b/projects/RabbitMQ.Client/client/impl/QuiescingSession.cs @@ -59,7 +59,7 @@ public override void HandleFrame(in InboundFrame frame) { if (frame.IsMethod()) { - MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload); + MethodBase method = Connection.Protocol.DecodeMethodFrom(frame.Payload.Span); if ((method.ProtocolClassId == ClassConstants.Channel) && (method.ProtocolMethodId == ChannelMethodConstants.CloseOk)) { diff --git a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs index ed21562aa7..33c93413ee 100644 --- a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs +++ b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs @@ -252,8 +252,7 @@ public async Task WriteFrameImpl() { int bufferSize = frame.GetMinimumBufferSize(); byte[] memoryArray = ArrayPool.Shared.Rent(bufferSize); - Memory slice = new Memory(memoryArray, 0, bufferSize); - frame.WriteTo(slice); + frame.WriteTo(new Span(memoryArray, 0, bufferSize)); _writer.Write(memoryArray, 0, bufferSize); ArrayPool.Shared.Return(memoryArray); } diff --git a/projects/RabbitMQ.Client/client/impl/WireFormatting.cs b/projects/RabbitMQ.Client/client/impl/WireFormatting.cs index 259b3b13ad..d97262b607 100644 --- a/projects/RabbitMQ.Client/client/impl/WireFormatting.cs +++ b/projects/RabbitMQ.Client/client/impl/WireFormatting.cs @@ -41,7 +41,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Runtime.InteropServices; using System.Text; using RabbitMQ.Client.Exceptions; @@ -93,14 +92,14 @@ public static void DecimalToAmqp(decimal value, out byte scale, out int mantissa (((uint)bitRepresentation[0]) & 0x7FFFFFFF)); } - public static IList ReadArray(ReadOnlyMemory memory, out int bytesRead) + public static IList ReadArray(ReadOnlySpan span, out int bytesRead) { List array = new List(); - long arrayLength = NetworkOrderDeserializer.ReadUInt32(memory.Span); + long arrayLength = NetworkOrderDeserializer.ReadUInt32(span); bytesRead = 4; while (bytesRead - 4 < arrayLength) { - object value = ReadFieldValue(memory.Slice(bytesRead), out int fieldValueBytesRead); + object value = ReadFieldValue(span.Slice(bytesRead), out int fieldValueBytesRead); bytesRead += fieldValueBytesRead; array.Add(value); } @@ -108,97 +107,103 @@ public static IList ReadArray(ReadOnlyMemory memory, out int bytesRead) return array; } - public static decimal ReadDecimal(ReadOnlyMemory memory) + public static decimal ReadDecimal(ReadOnlySpan span) { - byte scale = memory.Span[0]; - uint unsignedMantissa = NetworkOrderDeserializer.ReadUInt32(memory.Slice(1).Span); + byte scale = span[0]; + uint unsignedMantissa = NetworkOrderDeserializer.ReadUInt32(span.Slice(1)); return AmqpToDecimal(scale, unsignedMantissa); } - public static object ReadFieldValue(ReadOnlyMemory memory, out int bytesRead) + public static object ReadFieldValue(ReadOnlySpan span, out int bytesRead) { bytesRead = 1; - ReadOnlyMemory slice = memory.Slice(1); - switch ((char)memory.Span[0]) + switch ((char)span[0]) { case 'S': - byte[] result = ReadLongstr(slice); + byte[] result = ReadLongstr(span.Slice(1)); bytesRead += result.Length + 4; return result; case 'I': bytesRead += 4; - return NetworkOrderDeserializer.ReadInt32(slice.Span); + return NetworkOrderDeserializer.ReadInt32(span.Slice(1)); case 'i': bytesRead += 4; - return NetworkOrderDeserializer.ReadUInt32(slice.Span); + return NetworkOrderDeserializer.ReadUInt32(span.Slice(1)); case 'D': bytesRead += 5; - return ReadDecimal(slice); + return ReadDecimal(span.Slice(1)); case 'T': bytesRead += 8; - return ReadTimestamp(slice); + return ReadTimestamp(span.Slice(1)); case 'F': - Dictionary tableResult = ReadTable(slice, out int tableBytesRead); + Dictionary tableResult = ReadTable(span.Slice(1), out int tableBytesRead); bytesRead += tableBytesRead; return tableResult; case 'A': - IList arrayResult = ReadArray(slice, out int arrayBytesRead); + IList arrayResult = ReadArray(span.Slice(1), out int arrayBytesRead); bytesRead += arrayBytesRead; return arrayResult; case 'B': bytesRead += 1; - return slice.Span[0]; + return span[1]; case 'b': bytesRead += 1; - return (sbyte)slice.Span[0]; + return (sbyte)span[1]; case 'd': bytesRead += 8; - return NetworkOrderDeserializer.ReadDouble(slice.Span); + return NetworkOrderDeserializer.ReadDouble(span.Slice(1)); case 'f': bytesRead += 4; - return NetworkOrderDeserializer.ReadSingle(slice.Span); + return NetworkOrderDeserializer.ReadSingle(span.Slice(1)); case 'l': bytesRead += 8; - return NetworkOrderDeserializer.ReadInt64(slice.Span); + return NetworkOrderDeserializer.ReadInt64(span.Slice(1)); case 's': bytesRead += 2; - return NetworkOrderDeserializer.ReadInt16(slice.Span); + return NetworkOrderDeserializer.ReadInt16(span.Slice(1)); case 't': bytesRead += 1; - return slice.Span[0] != 0; + return span[1] != 0; case 'x': - byte[] binaryTableResult = ReadLongstr(slice); + byte[] binaryTableResult = ReadLongstr(span.Slice(1)); bytesRead += binaryTableResult.Length + 4; return new BinaryTableValue(binaryTableResult); case 'V': return null; default: - throw new SyntaxErrorException($"Unrecognised type in table: {(char)memory.Span[0]}"); + throw new SyntaxErrorException($"Unrecognised type in table: {(char)span[0]}"); } } - public static byte[] ReadLongstr(ReadOnlyMemory memory) + public static byte[] ReadLongstr(ReadOnlySpan span) { - int byteCount = (int)NetworkOrderDeserializer.ReadUInt32(memory.Span); + uint byteCount = NetworkOrderDeserializer.ReadUInt32(span); if (byteCount > int.MaxValue) { throw new SyntaxErrorException($"Long string too long; byte length={byteCount}, max={int.MaxValue}"); } - return memory.Slice(4, byteCount).ToArray(); + return span.Slice(4, (int)byteCount).ToArray(); } - public static string ReadShortstr(ReadOnlyMemory memory, out int bytesRead) + public static unsafe string ReadShortstr(ReadOnlySpan span, out int bytesRead) { - int byteCount = memory.Span[0]; - ReadOnlyMemory stringSlice = memory.Slice(1, byteCount); - if (MemoryMarshal.TryGetArray(stringSlice, out ArraySegment segment)) + int byteCount = span[0]; + if (byteCount == 0) + { + bytesRead = 1; + return string.Empty; + } + if (span.Length >= byteCount + 1) { bytesRead = 1 + byteCount; - return Encoding.UTF8.GetString(segment.Array, segment.Offset, segment.Count); + fixed (byte* bytes = &span.Slice(1).GetPinnableReference()) + { + return Encoding.UTF8.GetString(bytes, byteCount); + } } - throw new InvalidOperationException("Unable to get ArraySegment from memory"); + throw new ArgumentOutOfRangeException(nameof(span), $"Span has not enough space ({span.Length} instead of {byteCount + 1})"); } ///Reads an AMQP "table" definition from the reader. @@ -208,16 +213,21 @@ public static string ReadShortstr(ReadOnlyMemory memory, out int bytesRead /// x and V types and the AMQP 0-9-1 A type. /// /// A . - public static Dictionary ReadTable(ReadOnlyMemory memory, out int bytesRead) + public static Dictionary ReadTable(ReadOnlySpan span, out int bytesRead) { - Dictionary table = new Dictionary(); - long tableLength = NetworkOrderDeserializer.ReadUInt32(memory.Span); bytesRead = 4; + long tableLength = NetworkOrderDeserializer.ReadUInt32(span); + if (tableLength == 0) + { + return null; + } + + Dictionary table = new Dictionary(); while ((bytesRead - 4) < tableLength) { - string key = ReadShortstr(memory.Slice(bytesRead), out int keyBytesRead); + string key = ReadShortstr(span.Slice(bytesRead), out int keyBytesRead); bytesRead += keyBytesRead; - object value = ReadFieldValue(memory.Slice(bytesRead), out int valueBytesRead); + object value = ReadFieldValue(span.Slice(bytesRead), out int valueBytesRead); bytesRead += valueBytesRead; if (!table.ContainsKey(key)) @@ -229,19 +239,19 @@ public static Dictionary ReadTable(ReadOnlyMemory memory, return table; } - public static AmqpTimestamp ReadTimestamp(ReadOnlyMemory memory) + public static AmqpTimestamp ReadTimestamp(ReadOnlySpan span) { - ulong stamp = NetworkOrderDeserializer.ReadUInt64(memory.Span); + ulong stamp = NetworkOrderDeserializer.ReadUInt64(span); // 0-9 is afaict silent on the signedness of the timestamp. // See also MethodArgumentWriter.WriteTimestamp and AmqpTimestamp itself return new AmqpTimestamp((long)stamp); } - public static int WriteArray(Memory memory, IList val) + public static int WriteArray(Span span, IList val) { if (val == null) { - NetworkOrderSerializer.WriteUInt32(memory.Span, 0); + NetworkOrderSerializer.WriteUInt32(span, 0); return 4; } else @@ -249,10 +259,10 @@ public static int WriteArray(Memory memory, IList val) int bytesWritten = 0; for (int index = 0; index < val.Count; index++) { - bytesWritten += WriteFieldValue(memory.Slice(4 + bytesWritten), val[index]); + bytesWritten += WriteFieldValue(span.Slice(4 + bytesWritten), val[index]); } - NetworkOrderSerializer.WriteUInt32(memory.Span, (uint)bytesWritten); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); return 4 + bytesWritten; } } @@ -273,87 +283,80 @@ public static int GetArrayByteCount(IList val) return byteCount; } - public static int WriteDecimal(Memory memory, decimal value) + public static int WriteDecimal(Span span, decimal value) { DecimalToAmqp(value, out byte scale, out int mantissa); - memory.Span[0] = scale; - return 1 + WriteLong(memory.Slice(1), (uint)mantissa); + span[0] = scale; + return 1 + WriteLong(span.Slice(1), (uint)mantissa); } - public static int WriteFieldValue(Memory memory, object value) + public static int WriteFieldValue(Span span, object value) { if (value == null) { - memory.Span[0] = (byte)'V'; + span[0] = (byte)'V'; return 1; } - Memory slice = memory.Slice(1); + Span slice = span.Slice(1); switch (value) { case string val: - memory.Span[0] = (byte)'S'; - if (MemoryMarshal.TryGetArray(memory, out ArraySegment segment)) - { - int bytesWritten = Encoding.UTF8.GetBytes(val, 0, val.Length, segment.Array, segment.Offset + 5); - NetworkOrderSerializer.WriteUInt32(slice.Span, (uint)bytesWritten); - return 5 + bytesWritten; - } - - throw new WireFormattingException("Unable to get array segment from memory."); + span[0] = (byte)'S'; + return 1 + WriteLongstr(slice, val); case byte[] val: - memory.Span[0] = (byte)'S'; + span[0] = (byte)'S'; return 1 + WriteLongstr(slice, val); case int val: - memory.Span[0] = (byte)'I'; - NetworkOrderSerializer.WriteInt32(slice.Span, val); + span[0] = (byte)'I'; + NetworkOrderSerializer.WriteInt32(slice, val); return 5; case uint val: - memory.Span[0] = (byte)'i'; - NetworkOrderSerializer.WriteUInt32(slice.Span, val); + span[0] = (byte)'i'; + NetworkOrderSerializer.WriteUInt32(slice, val); return 5; case decimal val: - memory.Span[0] = (byte)'D'; + span[0] = (byte)'D'; return 1 + WriteDecimal(slice, val); case AmqpTimestamp val: - memory.Span[0] = (byte)'T'; + span[0] = (byte)'T'; return 1 + WriteTimestamp(slice, val); case IDictionary val: - memory.Span[0] = (byte)'F'; + span[0] = (byte)'F'; return 1 + WriteTable(slice, val); case IList val: - memory.Span[0] = (byte)'A'; + span[0] = (byte)'A'; return 1 + WriteArray(slice, val); case byte val: - memory.Span[0] = (byte)'B'; - memory.Span[1] = val; + span[0] = (byte)'B'; + span[1] = val; return 2; case sbyte val: - memory.Span[0] = (byte)'b'; - memory.Span[1] = (byte)val; + span[0] = (byte)'b'; + span[1] = (byte)val; return 2; case double val: - memory.Span[0] = (byte)'d'; - NetworkOrderSerializer.WriteDouble(slice.Span, val); + span[0] = (byte)'d'; + NetworkOrderSerializer.WriteDouble(slice, val); return 9; case float val: - memory.Span[0] = (byte)'f'; - NetworkOrderSerializer.WriteSingle(slice.Span, val); + span[0] = (byte)'f'; + NetworkOrderSerializer.WriteSingle(slice, val); return 5; case long val: - memory.Span[0] = (byte)'l'; - NetworkOrderSerializer.WriteInt64(slice.Span, val); + span[0] = (byte)'l'; + NetworkOrderSerializer.WriteInt64(slice, val); return 9; case short val: - memory.Span[0] = (byte)'s'; - NetworkOrderSerializer.WriteInt16(slice.Span, val); + span[0] = (byte)'s'; + NetworkOrderSerializer.WriteInt16(slice, val); return 3; case bool val: - memory.Span[0] = (byte)'t'; - memory.Span[1] = (byte)(val ? 1 : 0); + span[0] = (byte)'t'; + span[1] = (byte)(val ? 1 : 0); return 2; case BinaryTableValue val: - memory.Span[0] = (byte)'x'; + span[0] = (byte)'x'; return 1 + WriteLongstr(slice, val.Bytes); default: throw new WireFormattingException($"Value of type '{value.GetType().Name}' cannot appear as table value", value); @@ -397,64 +400,69 @@ public static int GetFieldValueByteCount(object value) } } - public static int WriteLong(Memory memory, uint val) + public static int WriteLong(Span span, uint val) { - NetworkOrderSerializer.WriteUInt32(memory.Span, val); + NetworkOrderSerializer.WriteUInt32(span, val); return 4; } - public static int WriteLonglong(Memory memory, ulong val) + public static int WriteLonglong(Span span, ulong val) { - NetworkOrderSerializer.WriteUInt64(memory.Span, val); + NetworkOrderSerializer.WriteUInt64(span, val); return 8; } - public static int WriteLongstr(Memory memory, byte[] val) + public static int WriteLongstr(Span span, ReadOnlySpan val) { - return WriteLongstr(memory, val, 0, val.Length); + WriteLong(span, (uint)val.Length); + val.CopyTo(span.Slice(4)); + return 4 + val.Length; } - public static int WriteLongstr(Memory memory, byte[] val, int index, int count) + public static int WriteShort(Span span, ushort val) { - WriteLong(memory, (uint)count); - val.AsMemory(index, count).CopyTo(memory.Slice(4)); - return 4 + count; + NetworkOrderSerializer.WriteUInt16(span, val); + return 2; } - public static int WriteShort(Memory memory, ushort val) + public static unsafe int WriteShortstr(Span span, string val) { - NetworkOrderSerializer.WriteUInt16(memory.Span, val); - return 2; + int maxLength = span.Length - 1; + if (maxLength > byte.MaxValue) + { + maxLength = byte.MaxValue; + } + fixed (char* chars = val) + fixed (byte* bytes = &span.Slice(1).GetPinnableReference()) + { + int bytesWritten = Encoding.UTF8.GetBytes(chars, val.Length, bytes, maxLength); + span[0] = (byte)bytesWritten; + return bytesWritten + 1; + } } - public static int WriteShortstr(Memory memory, string val) + public static unsafe int WriteLongstr(Span span, string val) { - if (MemoryMarshal.TryGetArray(memory, out ArraySegment segment)) + fixed (char* chars = val) + fixed (byte* bytes = &span.Slice(4).GetPinnableReference()) { - int bytesWritten = Encoding.UTF8.GetBytes(val, 0, val.Length, segment.Array, segment.Offset + 1); - if (bytesWritten <= byte.MaxValue) - { - memory.Span[0] = (byte)bytesWritten; - return bytesWritten + 1; - } - - throw new ArgumentOutOfRangeException(nameof(val), val, "Value exceeds the maximum allowed length of 255 bytes."); + int bytesWritten = Encoding.UTF8.GetBytes(chars, val.Length, bytes, span.Length); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); + return bytesWritten + 4; } - - throw new WireFormattingException("Unable to get array segment from memory."); } - public static int WriteTable(Memory memory, IDictionary val) + public static int WriteTable(Span span, IDictionary val) { - if (val == null) + if (val == null || val.Count == 0) { - NetworkOrderSerializer.WriteUInt32(memory.Span, 0); + NetworkOrderSerializer.WriteUInt32(span, 0); return 4; } else { // Let's only write after the length header. - Memory slice = memory.Slice(4); + Span slice = span.Slice(4); int bytesWritten = 0; foreach (DictionaryEntry entry in val) { @@ -462,30 +470,41 @@ public static int WriteTable(Memory memory, IDictionary val) bytesWritten += WriteFieldValue(slice.Slice(bytesWritten), entry.Value); } - NetworkOrderSerializer.WriteUInt32(memory.Span, (uint)bytesWritten); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); return 4 + bytesWritten; } } - public static int WriteTable(Memory memory, IDictionary val) + public static int WriteTable(Span span, IDictionary val) { - if (val == null) + if (val == null || val.Count == 0) { - NetworkOrderSerializer.WriteUInt32(memory.Span, 0); + NetworkOrderSerializer.WriteUInt32(span, 0); return 4; } else { // Let's only write after the length header. - Memory slice = memory.Slice(4); + Span slice = span.Slice(4); int bytesWritten = 0; - foreach (KeyValuePair entry in val) + if (val is Dictionary dict) { - bytesWritten += WriteShortstr(slice.Slice(bytesWritten), entry.Key); - bytesWritten += WriteFieldValue(slice.Slice(bytesWritten), entry.Value); + foreach (KeyValuePair entry in dict) + { + bytesWritten += WriteShortstr(slice.Slice(bytesWritten), entry.Key); + bytesWritten += WriteFieldValue(slice.Slice(bytesWritten), entry.Value); + } + } + else + { + foreach (KeyValuePair entry in val) + { + bytesWritten += WriteShortstr(slice.Slice(bytesWritten), entry.Key); + bytesWritten += WriteFieldValue(slice.Slice(bytesWritten), entry.Value); + } } - NetworkOrderSerializer.WriteUInt32(memory.Span, (uint)bytesWritten); + NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten); return 4 + bytesWritten; } } @@ -515,20 +534,31 @@ public static int GetTableByteCount(IDictionary val) return byteCount; } - foreach (KeyValuePair entry in val) + if (val is Dictionary dict) { - byteCount += Encoding.UTF8.GetByteCount(entry.Key) + 1; - byteCount += GetFieldValueByteCount(entry.Value); + foreach (KeyValuePair entry in dict) + { + byteCount += Encoding.UTF8.GetByteCount(entry.Key) + 1; + byteCount += GetFieldValueByteCount(entry.Value); + } + } + else + { + foreach (KeyValuePair entry in val) + { + byteCount += Encoding.UTF8.GetByteCount(entry.Key) + 1; + byteCount += GetFieldValueByteCount(entry.Value); + } } return byteCount; } - public static int WriteTimestamp(Memory memory, AmqpTimestamp val) + public static int WriteTimestamp(Span span, AmqpTimestamp val) { // 0-9 is afaict silent on the signedness of the timestamp. // See also MethodArgumentReader.ReadTimestamp and AmqpTimestamp itself - return WriteLonglong(memory, (ulong)val.UnixTime); + return WriteLonglong(span, (ulong)val.UnixTime); } } } diff --git a/projects/Unit/TestBasicProperties.cs b/projects/Unit/TestBasicProperties.cs index 3693309c8e..619c2d29a8 100644 --- a/projects/Unit/TestBasicProperties.cs +++ b/projects/Unit/TestBasicProperties.cs @@ -38,6 +38,7 @@ // Copyright (c) 2011-2020 VMware, Inc. or its affiliates. All rights reserved. //--------------------------------------------------------------------------- +using System; using NUnit.Framework; namespace RabbitMQ.Client.Unit @@ -104,12 +105,13 @@ public void TestNullableProperties_CanWrite( bool isMessageIdPresent = messageId != null; Assert.AreEqual(isMessageIdPresent, subject.IsMessageIdPresent()); - var writer = new Impl.ContentHeaderPropertyWriter(new byte[1024]); + Span span = new byte[1024]; + var writer = new Impl.ContentHeaderPropertyWriter(span); subject.WritePropertiesTo(ref writer); // Read from Stream var propertiesFromStream = new Framing.BasicProperties(); - var reader = new Impl.ContentHeaderPropertyReader(writer.Memory.Slice(0, writer.Offset)); + var reader = new Impl.ContentHeaderPropertyReader(span.Slice(0, writer.Offset)); propertiesFromStream.ReadPropertiesFrom(ref reader); Assert.AreEqual(clusterId, propertiesFromStream.ClusterId); @@ -137,12 +139,13 @@ public void TestProperties_ReplyTo([Values(null, "foo_1", "fanout://name/key")] string replyToAddress = result?.ToString(); Assert.AreEqual(isReplyToPresent, subject.IsReplyToPresent()); - var writer = new Impl.ContentHeaderPropertyWriter(new byte[1024]); + Span span = new byte[1024]; + var writer = new Impl.ContentHeaderPropertyWriter(span); subject.WritePropertiesTo(ref writer); // Read from Stream var propertiesFromStream = new Framing.BasicProperties(); - var reader = new Impl.ContentHeaderPropertyReader(writer.Memory.Slice(0, writer.Offset)); + var reader = new Impl.ContentHeaderPropertyReader(span.Slice(0, writer.Offset)); propertiesFromStream.ReadPropertiesFrom(ref reader); Assert.AreEqual(replyTo, propertiesFromStream.ReplyTo); diff --git a/projects/Unit/TestContentHeaderCodec.cs b/projects/Unit/TestContentHeaderCodec.cs index a2515b9455..2cfa8524f8 100644 --- a/projects/Unit/TestContentHeaderCodec.cs +++ b/projects/Unit/TestContentHeaderCodec.cs @@ -71,19 +71,21 @@ public void Check(ReadOnlyMemory actual, ReadOnlyMemory expected) [Test] public void TestPresence() { - var m_w = new ContentHeaderPropertyWriter(new byte[1024]); + var memory = new byte[1024]; + var m_w = new ContentHeaderPropertyWriter(memory); m_w.WritePresence(false); m_w.WritePresence(true); m_w.WritePresence(false); m_w.WritePresence(true); m_w.FinishPresence(); - Check(m_w.Memory.Slice(0, m_w.Offset), new byte[] { 0x50, 0x00 }); + Check(memory.AsMemory().Slice(0, m_w.Offset), new byte[] { 0x50, 0x00 }); } [Test] public void TestLongPresence() { - var m_w = new ContentHeaderPropertyWriter(new byte[1024]); + var memory = new byte[1024]; + var m_w = new ContentHeaderPropertyWriter(memory); m_w.WritePresence(false); m_w.WritePresence(true); @@ -95,15 +97,16 @@ public void TestLongPresence() } m_w.WritePresence(true); m_w.FinishPresence(); - Check(m_w.Memory.Slice(0, m_w.Offset), new byte[] { 0x50, 0x01, 0x00, 0x40 }); + Check(memory.AsMemory().Slice(0, m_w.Offset), new byte[] { 0x50, 0x01, 0x00, 0x40 }); } [Test] public void TestNoPresence() { - var m_w = new ContentHeaderPropertyWriter(new byte[1024]); + var memory = new byte[1024]; + var m_w = new ContentHeaderPropertyWriter(memory); m_w.FinishPresence(); - Check(m_w.Memory.Slice(0, m_w.Offset), new byte[] { 0x00, 0x00 }); + Check(memory.AsMemory().Slice(0, m_w.Offset), new byte[] { 0x00, 0x00 }); } [Test] diff --git a/projects/Unit/TestFieldTableFormatting.cs b/projects/Unit/TestFieldTableFormatting.cs index 1025adb7a4..5da8a035b3 100644 --- a/projects/Unit/TestFieldTableFormatting.cs +++ b/projects/Unit/TestFieldTableFormatting.cs @@ -139,7 +139,7 @@ [new string('A', TooLarge)] = null int bytesNeeded = WireFormatting.GetTableByteCount(t); byte[] bytes = new byte[bytesNeeded]; - Assert.Throws(() => WireFormatting.WriteTable(bytes, t)); + Assert.Throws(() => WireFormatting.WriteTable(bytes, t)); } [Test] diff --git a/projects/Unit/TestFrameFormatting.cs b/projects/Unit/TestFrameFormatting.cs new file mode 100644 index 0000000000..18c2d06439 --- /dev/null +++ b/projects/Unit/TestFrameFormatting.cs @@ -0,0 +1,70 @@ +// This source code is dual-licensed under the Apache License, version +// 2.0, and the Mozilla Public License, version 1.1. +// +// The APL v2.0: +// +//--------------------------------------------------------------------------- +// Copyright (c) 2007-2020 VMware, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//--------------------------------------------------------------------------- +// +// The MPL v1.1: +// +//--------------------------------------------------------------------------- +// The contents of this file are subject to the Mozilla Public License +// Version 1.1 (the "License"); you may not use this file except in +// compliance with the License. You may obtain a copy of the License +// at https://www.mozilla.org/MPL/ +// +// Software distributed under the License is distributed on an "AS IS" +// basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See +// the License for the specific language governing rights and +// limitations under the License. +// +// The Original Code is RabbitMQ. +// +// The Initial Developer of the Original Code is Pivotal Software, Inc. +// Copyright (c) 2007-2020 VMware, Inc. All rights reserved. +//--------------------------------------------------------------------------- + +using NUnit.Framework; + +using RabbitMQ.Client.Impl; + +namespace RabbitMQ.Client.Unit +{ + [TestFixture] + class TestFrameFormatting : WireFormattingFixture + { + [Test] + public void EmptyOutboundFrame() + { + var frame = new EmptyOutboundFrame(); + var body = new byte[frame.GetMinimumBufferSize()]; + + frame.WriteTo(body); + + Assert.AreEqual(0, frame.GetMinimumPayloadBufferSize()); + Assert.AreEqual(8, frame.GetMinimumBufferSize()); + Assert.AreEqual(Constants.FrameHeartbeat, body[0]); + Assert.AreEqual(0, body[1]); // channel + Assert.AreEqual(0, body[2]); // channel + Assert.AreEqual(0, body[3]); // payload size + Assert.AreEqual(0, body[4]); // payload size + Assert.AreEqual(0, body[5]); // payload size + Assert.AreEqual(0, body[6]); // payload size + Assert.AreEqual(Constants.FrameEnd, body[7]); + } + } +} diff --git a/projects/Unit/TestMethodArgumentCodec.cs b/projects/Unit/TestMethodArgumentCodec.cs index c2aefa196a..b1e56d1028 100644 --- a/projects/Unit/TestMethodArgumentCodec.cs +++ b/projects/Unit/TestMethodArgumentCodec.cs @@ -62,14 +62,8 @@ public static MethodArgumentReader Reader(byte[] bytes) return new MethodArgumentReader(bytes); } - public byte[] Contents(MethodArgumentWriter w) + public void Check(byte[] actual, byte[] expected) { - return w.Memory.Slice(0, w.Offset).ToArray(); - } - - public void Check(MethodArgumentWriter w, byte[] expected) - { - byte[] actual = Contents(w); try { Assert.AreEqual(expected, actual); @@ -95,10 +89,11 @@ public void TestTableLengthWrite() }; int bytesNeeded = WireFormatting.GetTableByteCount(t); - var writer = new MethodArgumentWriter(new byte[bytesNeeded]); + byte[] memory = new byte[bytesNeeded]; + var writer = new MethodArgumentWriter(memory); writer.WriteTable(t); Assert.AreEqual(bytesNeeded, writer.Offset); - Check(writer, new byte[] { 0x00, 0x00, 0x00, 0x0C, + Check(memory, new byte[] { 0x00, 0x00, 0x00, 0x0C, 0x03, 0x61, 0x62, 0x63, 0x53, 0x00, 0x00, 0x00, 0x03, 0x64, 0x65, 0x66 }); @@ -125,10 +120,11 @@ public void TestNestedTableWrite() }; t["x"] = x; int bytesNeeded = WireFormatting.GetTableByteCount(t); - var writer = new MethodArgumentWriter(new byte[bytesNeeded]); + byte[] memory = new byte[bytesNeeded]; + var writer = new MethodArgumentWriter(memory); writer.WriteTable(t); Assert.AreEqual(bytesNeeded, writer.Offset); - Check(writer, new byte[] { 0x00, 0x00, 0x00, 0x0E, + Check(memory, new byte[] { 0x00, 0x00, 0x00, 0x0E, 0x01, 0x78, 0x46, 0x00, 0x00, 0x00, 0x07, 0x01, 0x79, 0x49, 0x12, 0x34,