Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify managed SNI receive callback use #1186

Merged
merged 2 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,6 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNILoadHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIMarsConnection.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIMarsHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIMarsQueuedPacket.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNINpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPacket.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ internal abstract class SNIHandle
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public abstract uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null);
public abstract uint SendAsync(SNIPacket packet);

/// <summary>
/// Receive a packet synchronously
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,15 @@ public uint Send(SNIPacket packet)
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public uint SendAsync(SNIPacket packet, SNIAsyncCallback callback)
public uint SendAsync(SNIPacket packet)
{
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(s_className);
try
{
lock (this)
{
return _lowerHandle.SendAsync(packet, callback);
return _lowerHandle.SendAsync(packet);
}
}
finally
Expand Down Expand Up @@ -192,7 +191,7 @@ public void HandleReceiveError(SNIPacket packet)
Debug.Assert(Monitor.IsEntered(this), "HandleReceiveError was called without being locked.");
foreach (SNIMarsHandle handle in _sessions.Values)
{
if (packet.HasCompletionCallback)
if (packet.HasAsyncIOCompletionCallback)
{
handle.HandleReceiveError(packet);
#if DEBUG
Expand All @@ -215,7 +214,7 @@ public void HandleReceiveError(SNIPacket packet)
/// <param name="sniErrorCode">SNI error code</param>
public void HandleSendComplete(SNIPacket packet, uint sniErrorCode)
{
packet.InvokeCompletionCallback(sniErrorCode);
packet.InvokeAsyncIOCompletionCallback(sniErrorCode);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ internal sealed class SNIMarsHandle : SNIHandle
private readonly SNIMarsConnection _connection;
private readonly uint _status = TdsEnums.SNI_UNINITIALIZED;
private readonly Queue<SNIPacket> _receivedPacketQueue = new Queue<SNIPacket>();
private readonly Queue<SNIMarsQueuedPacket> _sendPacketQueue = new Queue<SNIMarsQueuedPacket>();
private readonly Queue<SNIPacket> _sendPacketQueue = new Queue<SNIPacket>();
private readonly object _callbackObject;
private readonly Guid _connectionId;
private readonly ushort _sessionId;
Expand Down Expand Up @@ -191,9 +191,8 @@ public override uint Send(SNIPacket packet)
/// Send packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback)
private uint InternalSendAsync(SNIPacket packet)
{
Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to send muxed packet without smux reservation in InternalSendAsync");
using (TrySNIEventScope.Create("SNIMarsHandle.InternalSendAsync | SNI | INFO | SCOPE | Entering Scope {0}"))
Expand All @@ -207,9 +206,9 @@ private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback)
}

SNIPacket muxedPacket = SetPacketSMUXHeader(packet);
muxedPacket.SetCompletionCallback(callback ?? HandleSendComplete);
muxedPacket.SetAsyncIOCompletionCallback(_handleSendCompleteCallback);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsHandle), EventType.INFO, "MARS Session Id {0}, _sequenceNumber {1}, _sendHighwater {2}, Sending packet", args0: ConnectionId, args1: _sequenceNumber, args2: _sendHighwater);
return _connection.SendAsync(muxedPacket, callback);
return _connection.SendAsync(muxedPacket);
}
}
}
Expand All @@ -222,7 +221,7 @@ private uint SendPendingPackets()
{
using (TrySNIEventScope.Create(nameof(SNIMarsHandle)))
{
SNIMarsQueuedPacket packet = null;
SNIPacket packet = null;

while (true)
{
Expand All @@ -233,7 +232,7 @@ private uint SendPendingPackets()
if (_sendPacketQueue.Count != 0)
{
packet = _sendPacketQueue.Peek();
uint result = InternalSendAsync(packet.Packet, packet.Callback);
uint result = InternalSendAsync(packet);

if (result != TdsEnums.SNI_SUCCESS && result != TdsEnums.SNI_SUCCESS_IO_PENDING)
{
Expand Down Expand Up @@ -264,15 +263,15 @@ private uint SendPendingPackets()
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
using (TrySNIEventScope.Create(nameof(SNIMarsHandle)))
{
packet.SetAsyncIOCompletionCallback(_handleSendCompleteCallback);
lock (this)
{
_sendPacketQueue.Enqueue(new SNIMarsQueuedPacket(packet, callback ?? _handleSendCompleteCallback));
_sendPacketQueue.Enqueue(packet);
}

SendPendingPackets();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ public override uint ReceiveAsync(ref SNIPacket packet)
{
SNIPacket errorPacket;
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);

packet.SetAsyncIOCompletionCallback(_receiveCallback);
try
{
packet.ReadFromStreamAsync(_stream, _receiveCallback);
packet.ReadFromStreamAsync(_stream);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, "Connection Id {0}, Rented and read packet asynchronously, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
Expand Down Expand Up @@ -288,13 +288,12 @@ public override uint Send(SNIPacket packet)
}
}

public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
using (TrySNIEventScope.Create(nameof(SNINpHandle)))
{
SNIAsyncCallback cb = callback ?? _sendCallback;
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, "Connection Id {0}, Packet writing to stream, dataLeft {1}", args0: _connectionId, args1: packet?.DataLeft);
packet.WriteToStreamAsync(_stream, cb, SNIProviders.NP_PROV);
packet.WriteToStreamAsync(_stream, _sendCallback, SNIProviders.NP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
Expand All @@ -19,14 +18,14 @@ namespace Microsoft.Data.SqlClient.SNI
/// </summary>
internal sealed class SNIPacket
{
private static readonly Action<Task<int>, object> s_readCallback = ReadFromStreamAsyncContinuation;
private int _dataLength; // the length of the data in the data segment, advanced by Append-ing data, does not include smux header length
private int _dataCapacity; // the total capacity requested, if the array is rented this may be less than the _data.Length, does not include smux header length
private int _dataOffset; // the start point of the data in the data segment, advanced by Take-ing data
private int _headerLength; // the amount of space at the start of the array reserved for the smux header, this is zeroed in SetHeader
// _headerOffset is not needed because it is always 0
private byte[] _data;
private SNIAsyncCallback _completionCallback;
private readonly Action<Task<int>, object> _readCallback;
private SNIAsyncCallback _asyncIOCompletionCallback;
#if DEBUG
internal readonly int _id; // in debug mode every packet is assigned a unique id so that the entire lifetime can be tracked when debugging
/// refcount = 0 means that a packet should only exist in the pool
Expand Down Expand Up @@ -85,7 +84,6 @@ public SNIPacket(SNIHandle owner, int id)
#endif
public SNIPacket()
{
_readCallback = ReadFromStreamAsyncContinuation;
}

/// <summary>
Expand All @@ -110,25 +108,19 @@ public SNIPacket()

public int ReservedHeaderSize => _headerLength;

public bool HasCompletionCallback => !(_completionCallback is null);
public bool HasAsyncIOCompletionCallback => _asyncIOCompletionCallback is not null;

/// <summary>
/// Set async completion callback
/// Set async receive callback
/// </summary>
/// <param name="completionCallback">Completion callback</param>
public void SetCompletionCallback(SNIAsyncCallback completionCallback)
{
_completionCallback = completionCallback;
}
/// <param name="asyncIOCompletionCallback">Completion callback</param>
public void SetAsyncIOCompletionCallback(SNIAsyncCallback asyncIOCompletionCallback) => _asyncIOCompletionCallback = asyncIOCompletionCallback;

/// <summary>
/// Invoke the completion callback
/// Invoke the receive callback
/// </summary>
/// <param name="sniErrorCode">SNI error</param>
public void InvokeCompletionCallback(uint sniErrorCode)
{
_completionCallback(this, sniErrorCode);
}
public void InvokeAsyncIOCompletionCallback(uint sniErrorCode) => _asyncIOCompletionCallback(this, sniErrorCode);

/// <summary>
/// Allocate space for data
Expand Down Expand Up @@ -253,7 +245,7 @@ public void Release()
_dataLength = 0;
_dataOffset = 0;
_headerLength = 0;
_completionCallback = null;
_asyncIOCompletionCallback = null;
IsOutOfBand = false;
}

Expand All @@ -273,49 +265,48 @@ public void ReadFromStream(Stream stream)
/// Read data from a stream asynchronously
/// </summary>
/// <param name="stream">Stream to read from</param>
/// <param name="callback">Completion callback</param>
public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
public void ReadFromStreamAsync(Stream stream)
{
stream.ReadAsync(_data, 0, _dataCapacity, CancellationToken.None)
.ContinueWith(
continuationAction: _readCallback,
state: callback,
continuationAction: s_readCallback,
state: this,
CancellationToken.None,
TaskContinuationOptions.DenyChildAttach,
TaskScheduler.Default
);
}

private void ReadFromStreamAsyncContinuation(Task<int> t, object state)
private static void ReadFromStreamAsyncContinuation(Task<int> task, object state)
{
SNIAsyncCallback callback = (SNIAsyncCallback)state;
SNIPacket packet = (SNIPacket)state;
bool error = false;
Exception e = t.Exception?.InnerException;
Exception e = task.Exception?.InnerException;
if (e != null)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, e);
#if DEBUG
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIPacket), EventType.ERR, "Connection Id {0}, Internal Exception occurred while reading data: {1}", args0: _owner?.ConnectionId, args1: e?.Message);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIPacket), EventType.ERR, "Connection Id {0}, Internal Exception occurred while reading data: {1}", args0: packet._owner?.ConnectionId, args1: e?.Message);
#endif
error = true;
}
else
{
_dataLength = t.Result;
packet._dataLength = task.Result;
#if DEBUG
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIPacket), EventType.INFO, "Connection Id {0}, Packet Id {1} _dataLength {2} read from stream.", args0: _owner?.ConnectionId, args1: _id, args2: _dataLength);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIPacket), EventType.INFO, "Connection Id {0}, Packet Id {1} _dataLength {2} read from stream.", args0: packet._owner?.ConnectionId, args1: packet._id, args2: packet._dataLength);
#endif
if (_dataLength == 0)
if (packet._dataLength == 0)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, Strings.SNI_ERROR_2);
#if DEBUG
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIPacket), EventType.ERR, "Connection Id {0}, No data read from stream, connection was terminated.", args0: _owner?.ConnectionId);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIPacket), EventType.ERR, "Connection Id {0}, No data read from stream, connection was terminated.", args0: packet._owner?.ConnectionId);
#endif
error = true;
}
}

callback(this, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
packet.InvokeAsyncIOCompletionCallback(error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -802,14 +802,12 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// Send a packet asynchronously
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="callback">Completion callback</param>
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
public override uint SendAsync(SNIPacket packet)
{
using (TrySNIEventScope.Create(nameof(SNITCPHandle)))
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
packet.WriteToStreamAsync(_stream, _sendCallback, SNIProviders.TCP_PROV);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Data sent to stream asynchronously", args0: _connectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
Expand All @@ -824,10 +822,10 @@ public override uint ReceiveAsync(ref SNIPacket packet)
{
SNIPacket errorPacket;
packet = RentPacket(headerSize: 0, dataSize: _bufferSize);

packet.SetAsyncIOCompletionCallback(_receiveCallback);
try
{
packet.ReadFromStreamAsync(_stream, _receiveCallback);
packet.ReadFromStreamAsync(_stream);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Data received from stream asynchronously", args0: _connectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
Expand Down