diff --git a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/SocketConnection.cs b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/SocketConnection.cs index 35cb1e0dff..5852f976f1 100644 --- a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/SocketConnection.cs +++ b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/SocketConnection.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. - +using System.Diagnostics; using System.Diagnostics.Contracts; using System.Net; using System.Net.Sockets; @@ -14,26 +14,36 @@ namespace System.ServiceModel.Channels { internal class SocketConnection : IConnection { - private static EventHandler s_onReceiveAsyncCompleted; - private static EventHandler s_onSocketSendCompleted; + static AsyncCallback s_onReceiveCompleted; + static EventHandler s_onReceiveAsyncCompleted; + static EventHandler s_onSocketSendCompleted; // common state private Socket _socket; - private bool _noDelay = false; - private TimeSpan _sendTimeout; - private TimeSpan _receiveTimeout; + private TimeSpan _asyncSendTimeout; + private TimeSpan _readFinTimeout; + private TimeSpan _asyncReceiveTimeout; + + // Socket.SendTimeout/Socket.ReceiveTimeout only work with the synchronous API calls and therefore they + // do not get updated when asynchronous Send/Read operations are performed. In order to make sure we + // Set the proper timeouts on the Socket itself we need to keep these two additional fields. + private TimeSpan _socketSyncSendTimeout; + private TimeSpan _socketSyncReceiveTimeout; + private CloseState _closeState; + private bool _isShutdown; + private bool _noDelay = false; private bool _aborted; // close state - private static Action s_onWaitForFinComplete = new Action(OnWaitForFinComplete); private TimeoutHelper _closeTimeoutHelper; - private bool _isShutdown; + private static Action s_onWaitForFinComplete = new Action(OnWaitForFinComplete); // read state - private SocketAsyncEventArgs _asyncReadEventArgs; - private TimeSpan _readFinTimeout; private int _asyncReadSize; + private SocketAsyncEventArgs _asyncReadEventArgs; + private byte[] _readBuffer; + private int _asyncReadBufferSize; private object _asyncReadState; private Action _asyncReadCallback; private Exception _asyncReadException; @@ -46,37 +56,100 @@ internal class SocketConnection : IConnection private Exception _asyncWriteException; private bool _asyncWritePending; - private static Action s_onSendTimeout; - private static Action s_onReceiveTimeout; private IOTimer _receiveTimer; + private static Action s_onReceiveTimeout; private IOTimer _sendTimer; + private static Action s_onSendTimeout; private string _timeoutErrorString; private TransferOperation _timeoutErrorTransferOperation; + private IPEndPoint _remoteEndpoint; private ConnectionBufferPool _connectionBufferPool; - private string _remoteEndpointAddressString; + private string _remoteEndpointAddress; - public SocketConnection(Socket socket, ConnectionBufferPool connectionBufferPool) + public SocketConnection(Socket socket, ConnectionBufferPool connectionBufferPool, bool autoBindToCompletionPort) { _connectionBufferPool = connectionBufferPool ?? throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull(nameof(connectionBufferPool)); _socket = socket ?? throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull(nameof(socket)); _closeState = CloseState.Open; - AsyncReadBuffer = _connectionBufferPool.Take(); - AsyncReadBufferSize = AsyncReadBuffer.Length; - _sendTimeout = _receiveTimeout = TimeSpan.MaxValue; - _closeState = CloseState.Open; - _socket.SendBufferSize = _socket.ReceiveBufferSize = AsyncReadBufferSize; - _sendTimeout = _receiveTimeout = TimeSpan.MaxValue; - } + _readBuffer = connectionBufferPool.Take(); + _asyncReadBufferSize = _readBuffer.Length; + _socket.SendBufferSize = _socket.ReceiveBufferSize = _asyncReadBufferSize; + _asyncSendTimeout = _asyncReceiveTimeout = TimeSpan.MaxValue; + _socketSyncSendTimeout = _socketSyncReceiveTimeout = TimeSpan.MaxValue; - public int AsyncReadBufferSize { get; } + _remoteEndpoint = null; - public byte[] AsyncReadBuffer { get; private set; } + if (autoBindToCompletionPort) + { + _socket.UseOnlyOverlappedIO = false; + } + + // In SMSvcHost, sockets must be duplicated to the target process. Binding a handle to a completion port + // prevents any duplicated handle from ever binding to a completion port. The target process is where we + // want to use completion ports for performance. This means that in SMSvcHost, socket.UseOnlyOverlappedIO + // must be set to true to prevent completion port use. + if (_socket.UseOnlyOverlappedIO) + { + // Init BeginRead state + if (s_onReceiveCompleted == null) + { + s_onReceiveCompleted = Fx.ThunkCallback(new AsyncCallback(OnReceiveCompleted)); + } + } + } + public int AsyncReadBufferSize + { + get { return _asyncReadBufferSize; } + } + + public byte[] AsyncReadBuffer + { + get + { + return _readBuffer; + } + } private object ThisLock { get { return this; } } + public IPEndPoint RemoteIPEndPoint + { + get + { + // this property should only be called on the receive path + if (_remoteEndpoint == null && _closeState == CloseState.Open) + { + try + { + _remoteEndpoint = (IPEndPoint)_socket.RemoteEndPoint; + } + catch (SocketException socketException) + { + // will never be a timeout error, so TimeSpan.Zero is ok + throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( + ConvertReceiveException(socketException, TimeSpan.Zero, TimeSpan.Zero)); + } + catch (ObjectDisposedException objectDisposedException) + { + Exception exceptionToThrow = ConvertObjectDisposedException(objectDisposedException, TransferOperation.Undefined); + if (ReferenceEquals(exceptionToThrow, objectDisposedException)) + { + throw; + } + else + { + throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(exceptionToThrow); + } + } + } + + return _remoteEndpoint; + } + } + private IOTimer SendTimer { get @@ -85,7 +158,7 @@ private IOTimer SendTimer { if (s_onSendTimeout == null) { - s_onSendTimeout = OnSendTimeout; + s_onSendTimeout = new Action(OnSendTimeout); } _sendTimer = new IOTimer(s_onSendTimeout, this); @@ -103,7 +176,7 @@ private IOTimer ReceiveTimer { if (s_onReceiveTimeout == null) { - s_onReceiveTimeout = OnReceiveTimeout; + s_onReceiveTimeout = new Action(OnReceiveTimeout); } _receiveTimer = new IOTimer(s_onReceiveTimeout, this); @@ -113,116 +186,64 @@ private IOTimer ReceiveTimer } } - private IPEndPoint RemoteEndPoint - { - get - { - if (!_socket.Connected) - { - return null; - } - return (IPEndPoint)_socket.RemoteEndPoint; - } - } - - private string RemoteEndpointAddressString + private string RemoteEndpointAddress { get { - if (_remoteEndpointAddressString == null) + if (_remoteEndpointAddress == null) { - IPEndPoint remote = RemoteEndPoint; - if (remote == null) + try { - return string.Empty; + if (TryGetEndpoints(out IPEndPoint local, out IPEndPoint remote)) + { + _remoteEndpointAddress = remote.Address + ":" + remote.Port; + } + else + { + //null indicates not initialized. + _remoteEndpointAddress = string.Empty; + } } - _remoteEndpointAddressString = remote.Address + ":" + remote.Port; - } + catch (Exception exception) + { + if (Fx.IsFatal(exception)) + { + throw; + } - return _remoteEndpointAddressString; + } + } + return _remoteEndpointAddress; } } - private static void OnReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e) + private static void OnReceiveTimeout(object state) { - ((SocketConnection)e.UserToken).OnReceiveAsync(sender, e); + SocketConnection thisPtr = (SocketConnection)state; + thisPtr.Abort(SR.Format(SR.SocketAbortedReceiveTimedOut, thisPtr._asyncReceiveTimeout), TransferOperation.Read); } - private static void OnSendAsyncCompleted(object sender, SocketAsyncEventArgs e) + private static void OnSendTimeout(object state) { - ((SocketConnection)e.UserToken).OnSendAsync(sender, e); + SocketConnection thisPtr = (SocketConnection)state; + thisPtr.Abort(4, // TraceEventType.Warning + SR.Format(SR.SocketAbortedSendTimedOut, thisPtr._asyncSendTimeout), TransferOperation.Write); } - private static void OnWaitForFinComplete(object state) + private static void OnReceiveCompleted(IAsyncResult result) { - // Callback for read on a socket which has had Shutdown called on it. When - // the response FIN packet is received from the remote host, the pending - // read will complete with 0 bytes read. If more than 0 bytes has been read, - // then something has gone wrong as we should have no pending data to be received. - SocketConnection thisPtr = (SocketConnection)state; - - try - { - int bytesRead; - - try - { - bytesRead = thisPtr.EndRead(); - - if (bytesRead > 0) - { - throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( - new CommunicationException(SR.Format(SR.SocketCloseReadReceivedData, thisPtr.RemoteEndPoint))); - } - } - catch (TimeoutException timeoutException) - { - throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException( - SR.Format(SR.SocketCloseReadTimeout, thisPtr.RemoteEndPoint, thisPtr._readFinTimeout), - timeoutException)); - } - - thisPtr.ContinueClose(thisPtr._closeTimeoutHelper.RemainingTime()); - } - catch (Exception e) - { - if (Fx.IsFatal(e)) - { - throw; - } - - Fx.Exception.TraceUnhandledException(e); - - // The user has no opportunity to clean up the connection in the async and linger - // code path, ensure cleanup finishes. - thisPtr.Abort(); - } + ((SocketConnection)result.AsyncState).OnReceive(result); } - private static void OnReceiveTimeout(SocketConnection socketConnection) + private static void OnReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e) { - try - { - socketConnection.Abort(SR.Format(SR.SocketAbortedReceiveTimedOut, socketConnection._receiveTimeout), TransferOperation.Read); - } - catch (SocketException) - { - // Guard against unhandled SocketException in timer callbacks - } + ((SocketConnection)e.UserToken).OnReceiveAsync(sender, e); } - private static void OnSendTimeout(SocketConnection socketConnection) + private static void OnSendAsyncCompleted(object sender, SocketAsyncEventArgs e) { - try - { - socketConnection.Abort(4, // TraceEventType.Warning - SR.Format(SR.SocketAbortedSendTimedOut, socketConnection._sendTimeout), TransferOperation.Write); - } - catch (SocketException) - { - // Guard against unhandled SocketException in timer callbacks - } + ((SocketConnection)e.UserToken).OnSendAsync(sender, e); } public void Abort() @@ -258,23 +279,26 @@ private void Abort(int traceEventType, string timeoutErrorString, TransferOperat _aborted = true; _closeState = CloseState.Closed; - if (!_asyncReadPending) + if (_asyncReadPending) + { + CancelReceiveTimer(); + } + else { DisposeReadEventArgs(); } - if (!_asyncWritePending) + if (_asyncWritePending) + { + CancelSendTimer(); + } + else { DisposeWriteEventArgs(); } - - DisposeReceiveTimer(); - DisposeSendTimer(); } - _socket.LingerState = new LingerOption(true, 0); - _socket.Shutdown(SocketShutdown.Both); - _socket.Dispose(); + _socket.Close(0); } private void AbortRead() @@ -326,23 +350,61 @@ private void CloseAsyncAndLinger() int bytesRead = EndRead(); - // Any NetTcp session handshake will have been completed at this point so if any data is returned, something - // very wrong has happened. if (bytesRead > 0) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( - new CommunicationException(SR.Format(SR.SocketCloseReadReceivedData, RemoteEndPoint))); + new CommunicationException(SR.Format(SR.SocketCloseReadReceivedData, _socket.RemoteEndPoint))); } } catch (TimeoutException timeoutException) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException( - SR.Format(SR.SocketCloseReadTimeout, RemoteEndPoint, _readFinTimeout), timeoutException)); + SR.Format(SR.SocketCloseReadTimeout, _socket.RemoteEndPoint, _readFinTimeout), timeoutException)); } ContinueClose(_closeTimeoutHelper.RemainingTime()); } + private static void OnWaitForFinComplete(object state) + { + SocketConnection thisPtr = (SocketConnection)state; + + try + { + int bytesRead; + + try + { + bytesRead = thisPtr.EndRead(); + + if (bytesRead > 0) + { + throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( + new CommunicationException(SR.Format(SR.SocketCloseReadReceivedData, thisPtr._socket.RemoteEndPoint))); + } + } + catch (TimeoutException timeoutException) + { + throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException( + SR.Format(SR.SocketCloseReadTimeout, thisPtr._socket.RemoteEndPoint, thisPtr._readFinTimeout), + timeoutException)); + } + + thisPtr.ContinueClose(thisPtr._closeTimeoutHelper.RemainingTime()); + } + catch (Exception e) + { + if (Fx.IsFatal(e)) + { + throw; + } + + // The user has no opportunity to clean up the connection in the async and linger + // code path, ensure cleanup finishes. + thisPtr.Abort(); + } + } + public void Close(TimeSpan timeout, bool asyncAndLinger) { lock (ThisLock) @@ -355,15 +417,10 @@ public void Close(TimeSpan timeout, bool asyncAndLinger) _closeState = CloseState.Closing; } - _closeTimeoutHelper = new TimeoutHelper(timeout); - // first we shutdown our send-side - Shutdown(timeout); - CloseCore(asyncAndLinger); - } + _closeTimeoutHelper = new TimeoutHelper(timeout); + Shutdown(_closeTimeoutHelper.RemainingTime()); - private void CloseCore(bool asyncAndLinger) - { if (asyncAndLinger) { CloseAsyncAndLinger(); @@ -378,10 +435,7 @@ private void CloseSync() { byte[] dummy = new byte[1]; - // A FIN (shutdown) packet has already been sent to the remote host and we're waiting for the remote - // host to send a FIN back. A pending read on a socket will complete returning zero bytes when a FIN - // packet is received. - + // then we check for a FIN from the other side (i.e. read zero) int bytesRead; _readFinTimeout = _closeTimeoutHelper.RemainingTime(); @@ -389,32 +443,25 @@ private void CloseSync() { bytesRead = ReadCore(dummy, 0, 1, _readFinTimeout, true); - // Any NetTcp session handshake will have been completed at this point so if any data is returned, something - // very wrong has happened. if (bytesRead > 0) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( - new CommunicationException(SR.Format(SR.SocketCloseReadReceivedData, RemoteEndPoint))); + new CommunicationException(SR.Format(SR.SocketCloseReadReceivedData, _socket.RemoteEndPoint))); } } catch (TimeoutException timeoutException) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException( - SR.Format(SR.SocketCloseReadTimeout, RemoteEndPoint, _readFinTimeout), timeoutException)); + SR.Format(SR.SocketCloseReadTimeout, _socket.RemoteEndPoint, _readFinTimeout), timeoutException)); } // finally we call Close with whatever time is remaining ContinueClose(_closeTimeoutHelper.RemainingTime()); } - private void ContinueClose(TimeSpan timeout) + public void ContinueClose(TimeSpan timeout) { - // Use linger to attempt a graceful socket shutdown. Allowing a clean shutdown handshake - // will allow the service side to close it's socket gracefully too. A hard shutdown would - // cause the server to receive an exception which affects performance and scalability. - _socket.LingerState = new LingerOption(true, (int)timeout.TotalSeconds); - _socket.Shutdown(SocketShutdown.Both); - _socket.Dispose(); + _socket.Close(TimeoutHelper.ToMilliseconds(timeout)); lock (ThisLock) { @@ -434,12 +481,10 @@ private void ContinueClose(TimeSpan timeout) } _closeState = CloseState.Closed; - DisposeReceiveTimer(); - DisposeSendTimer(); } } - private void Shutdown(TimeSpan timeout) + public void Shutdown(TimeSpan timeout) { lock (ThisLock) { @@ -451,12 +496,6 @@ private void Shutdown(TimeSpan timeout) _isShutdown = true; } - ShutdownCore(timeout); - } - - private void ShutdownCore(TimeSpan timeout) - { - // Attempt to close the socket gracefully by sending a shutdown (FIN) packet try { _socket.Shutdown(SocketShutdown.Send); @@ -464,7 +503,7 @@ private void ShutdownCore(TimeSpan timeout) catch (SocketException socketException) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( - ConvertSendException(socketException, TimeSpan.MaxValue)); + ConvertSendException(socketException, TimeSpan.MaxValue, _socketSyncSendTimeout)); } catch (ObjectDisposedException objectDisposedException) { @@ -500,29 +539,69 @@ private void ThrowIfClosed() } } - private Exception ConvertSendException(SocketException socketException, TimeSpan remainingTime) + private bool TryGetEndpoints(out IPEndPoint localIPEndpoint, out IPEndPoint remoteIPEndpoint) + { + localIPEndpoint = null; + remoteIPEndpoint = null; + + if (_closeState == CloseState.Open) + { + try + { + remoteIPEndpoint = _remoteEndpoint ?? (IPEndPoint)_socket.RemoteEndPoint; + localIPEndpoint = (IPEndPoint)_socket.LocalEndPoint; + } + catch (Exception exception) + { + if (Fx.IsFatal(exception)) + { + throw; + } + + } + } + + return localIPEndpoint != null && remoteIPEndpoint != null; + } + + public object GetCoreTransport() + { + return _socket; + } + + public IAsyncResult BeginValidate(Uri uri, AsyncCallback callback, object state) + { + return new CompletedAsyncResult(true, callback, state); + } + + public bool EndValidate(IAsyncResult result) + { + return CompletedAsyncResult.End(result); + } + + private Exception ConvertSendException(SocketException socketException, TimeSpan remainingTime, TimeSpan timeout) { - return ConvertTransferException(socketException, _sendTimeout, socketException, - _aborted, _timeoutErrorString, _timeoutErrorTransferOperation, this, remainingTime); + return ConvertTransferException(socketException, timeout, socketException, + TransferOperation.Write, _aborted, _timeoutErrorString, _timeoutErrorTransferOperation, this, remainingTime); } - private Exception ConvertReceiveException(SocketException socketException, TimeSpan remainingTime) + private Exception ConvertReceiveException(SocketException socketException, TimeSpan remainingTime, TimeSpan timeout) { - return ConvertTransferException(socketException, _receiveTimeout, socketException, - _aborted, _timeoutErrorString, _timeoutErrorTransferOperation, this, remainingTime); + return ConvertTransferException(socketException, timeout, socketException, + TransferOperation.Read, _aborted, _timeoutErrorString, _timeoutErrorTransferOperation, this, remainingTime); } internal static Exception ConvertTransferException(SocketException socketException, TimeSpan timeout, Exception originalException) { return ConvertTransferException(socketException, timeout, originalException, - false, null, TransferOperation.Undefined, null, TimeSpan.MaxValue); + TransferOperation.Undefined, false, null, TransferOperation.Undefined, null, TimeSpan.MaxValue); } private Exception ConvertObjectDisposedException(ObjectDisposedException originalException, TransferOperation transferOperation) { if (_timeoutErrorString != null) { - return ConvertTimeoutErrorException(originalException, _timeoutErrorString, _timeoutErrorTransferOperation); + return ConvertTimeoutErrorException(originalException, transferOperation, _timeoutErrorString, _timeoutErrorTransferOperation); } else if (_aborted) { @@ -535,30 +614,30 @@ private Exception ConvertObjectDisposedException(ObjectDisposedException origina } private static Exception ConvertTransferException(SocketException socketException, TimeSpan timeout, Exception originalException, - bool aborted, string timeoutErrorString, TransferOperation timeoutErrorTransferOperation, + TransferOperation transferOperation, bool aborted, string timeoutErrorString, TransferOperation timeoutErrorTransferOperation, SocketConnection socketConnection, TimeSpan remainingTime) { - if ((int)socketException.SocketErrorCode == UnsafeNativeMethods.ERROR_INVALID_HANDLE) + if (socketException.ErrorCode == UnsafeNativeMethods.ERROR_INVALID_HANDLE) { return new CommunicationObjectAbortedException(socketException.Message, socketException); } if (timeoutErrorString != null) { - return ConvertTimeoutErrorException(originalException, timeoutErrorString, timeoutErrorTransferOperation); + return ConvertTimeoutErrorException(originalException, transferOperation, timeoutErrorString, timeoutErrorTransferOperation); } // 10053 can occur due to our timeout sockopt firing, so map to TimeoutException in that case - if ((int)socketException.SocketErrorCode == UnsafeNativeMethods.WSAECONNABORTED && + if (socketException.ErrorCode == UnsafeNativeMethods.WSAECONNABORTED && remainingTime <= TimeSpan.Zero) { TimeoutException timeoutException = new TimeoutException(SR.Format(SR.TcpConnectionTimedOut, timeout), originalException); return timeoutException; } - if ((int)socketException.SocketErrorCode == UnsafeNativeMethods.WSAENETRESET || - (int)socketException.SocketErrorCode == UnsafeNativeMethods.WSAECONNABORTED || - (int)socketException.SocketErrorCode == UnsafeNativeMethods.WSAECONNRESET) + if (socketException.ErrorCode == UnsafeNativeMethods.WSAENETRESET || + socketException.ErrorCode == UnsafeNativeMethods.WSAECONNABORTED || + socketException.ErrorCode == UnsafeNativeMethods.WSAECONNRESET) { if (aborted) { @@ -570,7 +649,7 @@ private static Exception ConvertTransferException(SocketException socketExceptio return communicationException; } } - else if ((int)socketException.SocketErrorCode == UnsafeNativeMethods.WSAETIMEDOUT) + else if (socketException.ErrorCode == UnsafeNativeMethods.WSAETIMEDOUT) { TimeoutException timeoutException = new TimeoutException(SR.Format(SR.TcpConnectionTimedOut, timeout), originalException); return timeoutException; @@ -579,21 +658,25 @@ private static Exception ConvertTransferException(SocketException socketExceptio { if (aborted) { - return new CommunicationObjectAbortedException(SR.Format(SR.TcpTransferError, (int)socketException.SocketErrorCode, socketException.Message), originalException); + return new CommunicationObjectAbortedException(SR.Format(SR.TcpTransferError, socketException.ErrorCode, socketException.Message), originalException); } else { - CommunicationException communicationException = new CommunicationException(SR.Format(SR.TcpTransferError, (int)socketException.SocketErrorCode, socketException.Message), originalException); + CommunicationException communicationException = new CommunicationException(SR.Format(SR.TcpTransferError, socketException.ErrorCode, socketException.Message), originalException); return communicationException; } } } - private static Exception ConvertTimeoutErrorException(Exception originalException, string timeoutErrorString, TransferOperation timeoutErrorTransferOperation) + private static Exception ConvertTimeoutErrorException(Exception originalException, + TransferOperation transferOperation, string timeoutErrorString, TransferOperation timeoutErrorTransferOperation) { - Contract.Assert(timeoutErrorString != null, "Argument timeoutErrorString must not be null."); + if (timeoutErrorString == null) + { + Fx.Assert("Argument timeoutErrorString must not be null."); + } - if (timeoutErrorTransferOperation != TransferOperation.Undefined) + if (transferOperation == timeoutErrorTransferOperation) { return new TimeoutException(timeoutErrorString, originalException); } @@ -605,38 +688,20 @@ private static Exception ConvertTimeoutErrorException(Exception originalExceptio public AsyncCompletionResult BeginWrite(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout, Action callback, object state) - { - if (WcfEventSource.Instance.SocketAsyncWriteStartIsEnabled()) - { - TraceWriteStart(size, true); - } - - return BeginWriteCore(buffer, offset, size, immediate, timeout, callback, state); - } - - private void TraceWriteStart(int size, bool async) - { - if (!async) - { - WcfEventSource.Instance.SocketWriteStart(_socket.GetHashCode(), size, RemoteEndpointAddressString); - } - else - { - WcfEventSource.Instance.SocketAsyncWriteStart(_socket.GetHashCode(), size, RemoteEndpointAddressString); - } - } - - private AsyncCompletionResult BeginWriteCore(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout, - Action callback, object state) { ConnectionUtilities.ValidateBufferBounds(buffer, offset, size); bool abortWrite = true; try { + if (WcfEventSource.Instance.SocketAsyncWriteStartIsEnabled()) + { + TraceWriteStart(size, true); + } + lock (ThisLock) { - Contract.Assert(!_asyncWritePending, "Called BeginWrite twice."); + Fx.Assert(!_asyncWritePending, "Called BeginWrite twice."); ThrowIfClosed(); EnsureWriteEventArgs(); SetImmediate(immediate); @@ -662,7 +727,7 @@ private AsyncCompletionResult BeginWriteCore(byte[] buffer, int offset, int size catch (SocketException socketException) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( - ConvertSendException(socketException, TimeSpan.MaxValue)); + ConvertSendException(socketException, TimeSpan.MaxValue, _asyncSendTimeout)); } catch (ObjectDisposedException objectDisposedException) { @@ -686,11 +751,6 @@ private AsyncCompletionResult BeginWriteCore(byte[] buffer, int offset, int size } public void EndWrite() - { - EndWriteCore(); - } - - private void EndWriteCore() { if (_asyncWriteException != null) { @@ -702,8 +762,7 @@ private void EndWriteCore() { if (!_asyncWritePending) { - Contract.Assert(false, "SocketConnection.EndWrite called with no write pending."); - throw new Exception("SocketConnection.EndWrite called with no write pending."); + throw Fx.AssertAndThrow("SocketConnection.EndWrite called with no write pending."); } SetUserToken(_asyncWriteEventArgs, null); @@ -716,6 +775,73 @@ private void EndWriteCore() } } + private void OnSendAsync(object sender, SocketAsyncEventArgs eventArgs) + { + Fx.Assert(eventArgs != null, "Argument 'eventArgs' cannot be NULL."); + CancelSendTimer(); + + try + { + HandleSendAsyncCompleted(); + Fx.Assert(eventArgs.BytesTransferred == _asyncWriteEventArgs.Count, "The socket SendAsync did not send all the bytes."); + } + catch (SocketException socketException) + { + _asyncWriteException = ConvertSendException(socketException, TimeSpan.MaxValue, _asyncSendTimeout); + } + catch (Exception exception) + { + if (Fx.IsFatal(exception)) + { + throw; + } + + _asyncWriteException = exception; + } + + FinishWrite(); + } + + private void HandleSendAsyncCompleted() + { + if (_asyncWriteEventArgs.SocketError == SocketError.Success) + { + return; + } + + throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SocketException((int)_asyncWriteEventArgs.SocketError)); + } + + // This method should be called inside ThisLock + private void DisposeWriteEventArgs() + { + if (_asyncWriteEventArgs != null) + { + _asyncWriteEventArgs.Completed -= s_onSocketSendCompleted; + _asyncWriteEventArgs.Dispose(); + } + } + + private void AbortWrite() + { + lock (ThisLock) + { + if (_asyncWritePending) + { + if (_closeState != CloseState.Closed) + { + SetUserToken(_asyncWriteEventArgs, null); + _asyncWritePending = false; + CancelSendTimer(); + } + else + { + DisposeWriteEventArgs(); + } + } + } + } + private void FinishWrite() { Action asyncWriteCallback = _asyncWriteCallback; @@ -728,11 +854,6 @@ private void FinishWrite() } public void Write(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout) - { - WriteCore(buffer, offset, size, immediate, timeout); - } - - private void WriteCore(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout) { // as per http://support.microsoft.com/default.aspx?scid=kb%3ben-us%3b201213 // we shouldn't write more than 64K synchronously to a socket @@ -759,7 +880,7 @@ private void WriteCore(byte[] buffer, int offset, int size, bool immediate, Time catch (SocketException socketException) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( - ConvertSendException(socketException, timeoutHelper.RemainingTime())); + ConvertSendException(socketException, timeoutHelper.RemainingTime(), _socketSyncSendTimeout)); } catch (ObjectDisposedException objectDisposedException) { @@ -775,6 +896,18 @@ private void WriteCore(byte[] buffer, int offset, int size, bool immediate, Time } } + private void TraceWriteStart(int size, bool async) + { + if (!async) + { + WcfEventSource.Instance.SocketWriteStart(_socket.GetHashCode(), size, RemoteEndpointAddress); + } + else + { + WcfEventSource.Instance.SocketAsyncWriteStart(_socket.GetHashCode(), size, RemoteEndpointAddress); + } + } + public void Write(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout, BufferManager bufferManager) { try @@ -791,13 +924,7 @@ public int Read(byte[] buffer, int offset, int size, TimeSpan timeout) { ConnectionUtilities.ValidateBufferBounds(buffer, offset, size); ThrowIfNotOpen(); - int bytesRead = ReadCore(buffer, offset, size, timeout, false); - if (WcfEventSource.Instance.SocketReadStopIsEnabled()) - { - TraceSocketReadStop(bytesRead, false); - } - - return bytesRead; + return ReadCore(buffer, offset, size, timeout, false); } private int ReadCore(byte[] buffer, int offset, int size, TimeSpan timeout, bool closing) @@ -812,7 +939,7 @@ private int ReadCore(byte[] buffer, int offset, int size, TimeSpan timeout, bool catch (SocketException socketException) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError( - ConvertReceiveException(socketException, timeoutHelper.RemainingTime())); + ConvertReceiveException(socketException, timeoutHelper.RemainingTime(), _socketSyncReceiveTimeout)); } catch (ObjectDisposedException objectDisposedException) { @@ -830,32 +957,26 @@ private int ReadCore(byte[] buffer, int offset, int size, TimeSpan timeout, bool return bytesRead; } - public virtual AsyncCompletionResult BeginRead(int offset, int size, TimeSpan timeout, - Action callback, object state) - { - ConnectionUtilities.ValidateBufferBounds(AsyncReadBufferSize, offset, size); - ThrowIfNotOpen(); - var completionResult = BeginReadCore(offset, size, timeout, callback, state); - if (completionResult == AsyncCompletionResult.Completed && WcfEventSource.Instance.SocketReadStopIsEnabled()) - { - TraceSocketReadStop(_asyncReadSize, true); - } - - return completionResult; - } - private void TraceSocketReadStop(int bytesRead, bool async) { if (!async) { - WcfEventSource.Instance.SocketReadStop((_socket != null) ? _socket.GetHashCode() : -1, bytesRead, RemoteEndpointAddressString); + WcfEventSource.Instance.SocketReadStop((_socket != null) ? _socket.GetHashCode() : -1, bytesRead, RemoteEndpointAddress); } else { - WcfEventSource.Instance.SocketAsyncReadStop((_socket != null) ? _socket.GetHashCode() : -1, bytesRead, RemoteEndpointAddressString); + WcfEventSource.Instance.SocketAsyncReadStop((_socket != null) ? _socket.GetHashCode() : -1, bytesRead, RemoteEndpointAddress); } } + public virtual AsyncCompletionResult BeginRead(int offset, int size, TimeSpan timeout, + Action callback, object state) + { + ConnectionUtilities.ValidateBufferBounds(AsyncReadBufferSize, offset, size); + ThrowIfNotOpen(); + return BeginReadCore(offset, size, timeout, callback, state); + } + private AsyncCompletionResult BeginReadCore(int offset, int size, TimeSpan timeout, Action callback, object state) { @@ -874,27 +995,48 @@ private AsyncCompletionResult BeginReadCore(int offset, int size, TimeSpan timeo try { - if (offset != _asyncReadEventArgs.Offset || - size != _asyncReadEventArgs.Count) + if (_socket.UseOnlyOverlappedIO) { - _asyncReadEventArgs.SetBuffer(offset, size); - } + // ReceiveAsync does not respect UseOnlyOverlappedIO but BeginReceive does. + IAsyncResult result = _socket.BeginReceive(AsyncReadBuffer, offset, size, SocketFlags.None, s_onReceiveCompleted, this); + + if (!result.CompletedSynchronously) + { + abortRead = false; + return AsyncCompletionResult.Queued; + } - if (ReceiveAsync()) + _asyncReadSize = _socket.EndReceive(result); + } + else { - abortRead = false; - return AsyncCompletionResult.Queued; + if (offset != _asyncReadEventArgs.Offset || + size != _asyncReadEventArgs.Count) + { + _asyncReadEventArgs.SetBuffer(offset, size); + } + + if (ReceiveAsync()) + { + abortRead = false; + return AsyncCompletionResult.Queued; + } + + HandleReceiveAsyncCompleted(); + _asyncReadSize = _asyncReadEventArgs.BytesTransferred; } - HandleReceiveAsyncCompleted(); - _asyncReadSize = _asyncReadEventArgs.BytesTransferred; + if (WcfEventSource.Instance.SocketReadStopIsEnabled()) + { + TraceSocketReadStop(_asyncReadSize, true); + } abortRead = false; return AsyncCompletionResult.Completed; } catch (SocketException socketException) { - throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(ConvertReceiveException(socketException, TimeSpan.MaxValue)); + throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(ConvertReceiveException(socketException, TimeSpan.MaxValue, _asyncReceiveTimeout)); } catch (ObjectDisposedException objectDisposedException) { @@ -922,19 +1064,61 @@ private bool ReceiveAsync() return _socket.ReceiveAsync(_asyncReadEventArgs); } + private void OnReceive(IAsyncResult result) + { + CancelReceiveTimer(); + if (result.CompletedSynchronously) + { + return; + } + + try + { + _asyncReadSize = _socket.EndReceive(result); + + if (WcfEventSource.Instance.SocketReadStopIsEnabled()) + { + TraceSocketReadStop(_asyncReadSize, true); + } + } + catch (SocketException socketException) + { + _asyncReadException = ConvertReceiveException(socketException, TimeSpan.MaxValue, _asyncReceiveTimeout); + } + catch (ObjectDisposedException objectDisposedException) + { + _asyncReadException = ConvertObjectDisposedException(objectDisposedException, TransferOperation.Read); + } + catch (Exception exception) + { + if (Fx.IsFatal(exception)) + { + throw; + } + _asyncReadException = exception; + } + + FinishRead(); + } + private void OnReceiveAsync(object sender, SocketAsyncEventArgs eventArgs) { - Contract.Assert(eventArgs != null, "Argument 'eventArgs' cannot be NULL."); + Fx.Assert(eventArgs != null, "Argument 'eventArgs' cannot be NULL."); CancelReceiveTimer(); try { HandleReceiveAsyncCompleted(); _asyncReadSize = eventArgs.BytesTransferred; + + if (WcfEventSource.Instance.SocketReadStopIsEnabled()) + { + TraceSocketReadStop(_asyncReadSize, true); + } } catch (SocketException socketException) { - _asyncReadException = ConvertReceiveException(socketException, TimeSpan.MaxValue); + _asyncReadException = ConvertReceiveException(socketException, TimeSpan.MaxValue, _asyncReceiveTimeout); } catch (Exception exception) { @@ -958,14 +1142,8 @@ private void HandleReceiveAsyncCompleted() throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SocketException((int)_asyncReadEventArgs.SocketError)); } - private void FinishRead() { - if (_asyncReadException != null && WcfEventSource.Instance.SocketReadStopIsEnabled()) - { - TraceSocketReadStop(_asyncReadSize, true); - } - Action asyncReadCallback = _asyncReadCallback; object asyncReadState = _asyncReadState; @@ -977,12 +1155,6 @@ private void FinishRead() // Both BeginRead/ReadAsync paths completed themselves. EndRead's only job is to deliver the result. public int EndRead() - { - return EndReadCore(); - } - - // Both BeginRead/ReadAsync paths completed themselves. EndRead's only job is to deliver the result. - private int EndReadCore() { if (_asyncReadException != null) { @@ -994,8 +1166,7 @@ private int EndReadCore() { if (!_asyncReadPending) { - Contract.Assert(false, "SocketConnection.EndRead called with no read pending."); - throw new Exception("SocketConnection.EndRead called with no read pending."); + throw Fx.AssertAndThrow("SocketConnection.EndRead called with no read pending."); } SetUserToken(_asyncReadEventArgs, null); @@ -1023,108 +1194,15 @@ private void DisposeReadEventArgs() TryReturnReadBuffer(); } - // This method should be called inside ThisLock - private void DisposeReceiveTimer() - { - if (_receiveTimer != null) - { - _receiveTimer.Dispose(); - } - } - - private void OnSendAsync(object sender, SocketAsyncEventArgs eventArgs) - { - Contract.Assert(eventArgs != null, "Argument 'eventArgs' cannot be NULL."); - CancelSendTimer(); - - try - { - HandleSendAsyncCompleted(); - Contract.Assert(eventArgs.BytesTransferred == _asyncWriteEventArgs.Count, "The socket SendAsync did not send all the bytes."); - } - catch (SocketException socketException) - { - _asyncWriteException = ConvertSendException(socketException, TimeSpan.MaxValue); - } - catch (Exception exception) - { - if (Fx.IsFatal(exception)) - { - throw; - } - - _asyncWriteException = exception; - } - - FinishWrite(); - } - - private void HandleSendAsyncCompleted() - { - if (_asyncWriteEventArgs.SocketError == SocketError.Success) - { - return; - } - - throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SocketException((int)_asyncWriteEventArgs.SocketError)); - } - - // This method should be called inside ThisLock - private void DisposeWriteEventArgs() - { - if (_asyncWriteEventArgs != null) - { - _asyncWriteEventArgs.Completed -= s_onSocketSendCompleted; - _asyncWriteEventArgs.Dispose(); - } - } - - // This method should be called inside ThisLock - private void DisposeSendTimer() - { - if (_sendTimer != null) - { - _sendTimer.Dispose(); - } - } - - private void AbortWrite() - { - lock (ThisLock) - { - if (_asyncWritePending) - { - if (_closeState != CloseState.Closed) - { - SetUserToken(_asyncWriteEventArgs, null); - _asyncWritePending = false; - CancelSendTimer(); - } - else - { - DisposeWriteEventArgs(); - } - } - } - } - - // This method should be called inside ThisLock - private void ReturnReadBuffer() - { - // We release the buffer only if there is no outstanding I/O - TryReturnReadBuffer(); - } - - // This method should be called inside ThisLock private void TryReturnReadBuffer() { // The buffer must not be returned and nulled when an abort occurs. Since the buffer // is also accessed by higher layers, code that has not yet realized the stack is // aborted may be attempting to read from the buffer. - if (AsyncReadBuffer != null && !_aborted) + if (_readBuffer != null && !_aborted) { - _connectionBufferPool.Return(AsyncReadBuffer); - AsyncReadBuffer = null; + _connectionBufferPool.Return(_readBuffer); + _readBuffer = null; } } @@ -1164,7 +1242,7 @@ private void SetReadTimeout(TimeSpan timeout, bool synchronous, bool closing) new TimeoutException(SR.Format(SR.TcpConnectionTimedOut, timeout))); } - if (UpdateTimeout(_receiveTimeout, timeout)) + if (ShouldUpdateTimeout(_socketSyncReceiveTimeout, timeout)) { lock (ThisLock) { @@ -1174,12 +1252,12 @@ private void SetReadTimeout(TimeSpan timeout, bool synchronous, bool closing) } _socket.ReceiveTimeout = TimeoutHelper.ToMilliseconds(timeout); } - _receiveTimeout = timeout; + _socketSyncReceiveTimeout = timeout; } } else { - _receiveTimeout = timeout; + _asyncReceiveTimeout = timeout; if (timeout == TimeSpan.MaxValue) { CancelReceiveTimer(); @@ -1204,19 +1282,19 @@ private void SetWriteTimeout(TimeSpan timeout, bool synchronous) new TimeoutException(SR.Format(SR.TcpConnectionTimedOut, timeout))); } - if (UpdateTimeout(_sendTimeout, timeout)) + if (ShouldUpdateTimeout(_socketSyncSendTimeout, timeout)) { lock (ThisLock) { ThrowIfNotOpen(); _socket.SendTimeout = TimeoutHelper.ToMilliseconds(timeout); } - _sendTimeout = timeout; + _socketSyncSendTimeout = timeout; } } else { - _sendTimeout = timeout; + _asyncSendTimeout = timeout; if (timeout == TimeSpan.MaxValue) { CancelSendTimer(); @@ -1228,7 +1306,7 @@ private void SetWriteTimeout(TimeSpan timeout, bool synchronous) } } - private bool UpdateTimeout(TimeSpan oldTimeout, TimeSpan newTimeout) + private bool ShouldUpdateTimeout(TimeSpan oldTimeout, TimeSpan newTimeout) { if (oldTimeout == newTimeout) { @@ -1253,7 +1331,7 @@ private void EnsureReadEventArgs() } _asyncReadEventArgs = new SocketAsyncEventArgs(); - _asyncReadEventArgs.SetBuffer(AsyncReadBuffer, 0, AsyncReadBuffer.Length); + _asyncReadEventArgs.SetBuffer(_readBuffer, 0, _readBuffer.Length); _asyncReadEventArgs.Completed += s_onReceiveAsyncCompleted; } } @@ -1274,11 +1352,6 @@ private void EnsureWriteEventArgs() } } - public object GetCoreTransport() - { - return _socket; - } - private enum CloseState { Open, @@ -1313,7 +1386,7 @@ private IConnection CreateConnection(IPAddress address, int port) AddressFamily addressFamily = address.AddressFamily; socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp); socket.Connect(new IPEndPoint(address, port)); - return new SocketConnection(socket, _connectionBufferPool); + return new SocketConnection(socket, _connectionBufferPool, false); } catch { @@ -1330,7 +1403,7 @@ private async Task CreateConnectionAsync(IPAddress address, int por AddressFamily addressFamily = address.AddressFamily; socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp); await socket.ConnectAsync(new IPEndPoint(address, port)); - return new SocketConnection(socket, _connectionBufferPool); + return new SocketConnection(socket, _connectionBufferPool, false); } catch {