diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs deleted file mode 100644 index 6cc6918ea..000000000 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs +++ /dev/null @@ -1,50 +0,0 @@ -#if NET6_0_OR_GREATER - -using System; -using System.Diagnostics; -using System.Net.Sockets; -using System.Threading; -using System.Threading.Tasks; - -namespace Renci.SshNet.Abstractions -{ - internal static partial class SocketAbstraction - { - public static ValueTask ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) - { - return socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken); - } - - public static ValueTask SendAsync(Socket socket, ReadOnlyMemory data, CancellationToken cancellationToken = default) - { - Debug.Assert(socket != null); - Debug.Assert(data.Length > 0); - - if (cancellationToken.IsCancellationRequested) - { - return ValueTask.FromCanceled(cancellationToken); - } - - return SendAsyncCore(socket, data, cancellationToken); - - static async ValueTask SendAsyncCore(Socket socket, ReadOnlyMemory data, CancellationToken cancellationToken) - { - do - { - try - { - var bytesSent = await socket.SendAsync(data, SocketFlags.None, cancellationToken).ConfigureAwait(false); - data = data.Slice(bytesSent); - } - catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode)) - { - // Buffer may be full; attempt a short delay and retry - await Task.Delay(30, cancellationToken).ConfigureAwait(false); - } - } - while (data.Length > 0); - } - } - } -} -#endif // NET6_0_OR_GREATER diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index e1a12362e..ee094c3c3 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -6,84 +6,25 @@ using System.Threading.Tasks; using Renci.SshNet.Common; -using Renci.SshNet.Messages.Transport; namespace Renci.SshNet.Abstractions { internal static partial class SocketAbstraction { - public static bool CanRead(Socket socket) - { - if (socket.Connected) - { - return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0; - } - - return false; - } - - /// - /// Returns a value indicating whether the specified can be used - /// to send data. - /// - /// The to check. - /// - /// if can be written to; otherwise, . - /// - public static bool CanWrite(Socket socket) - { - if (socket != null && socket.Connected) - { - return socket.Poll(-1, SelectMode.SelectWrite); - } - - return false; - } - - public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout) - { - var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: true); - return socket; - } - public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout) { - ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false); - } - - public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken) - { - await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false); - } - - private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket) - { - var connectCompleted = new ManualResetEvent(initialState: false); - var args = new SocketAsyncEventArgs - { - UserToken = connectCompleted, - RemoteEndPoint = remoteEndpoint - }; - args.Completed += ConnectCompleted; + using var connectCompleted = new ManualResetEventSlim(initialState: false); + using var args = new SocketAsyncEventArgs + { + RemoteEndPoint = remoteEndpoint + }; + args.Completed += (_, _) => connectCompleted.Set(); if (socket.ConnectAsync(args)) { - if (!connectCompleted.WaitOne(connectTimeout)) + if (!connectCompleted.Wait(connectTimeout)) { - // avoid ObjectDisposedException in ConnectCompleted - args.Completed -= ConnectCompleted; - if (ownsSocket) - { - // dispose Socket - socket.Dispose(); - } - - // dispose ManualResetEvent - connectCompleted.Dispose(); - - // dispose SocketAsyncEventArgs - args.Dispose(); + socket.Dispose(); throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, "Connection failed to establish within {0:F0} milliseconds.", @@ -91,61 +32,12 @@ private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSp } } - // dispose ManualResetEvent - connectCompleted.Dispose(); - if (args.SocketError != SocketError.Success) { var socketError = (int) args.SocketError; - if (ownsSocket) - { - // dispose Socket - socket.Dispose(); - } - - // dispose SocketAsyncEventArgs - args.Dispose(); - throw new SocketException(socketError); } - - // dispose SocketAsyncEventArgs - args.Dispose(); - } - - public static void ClearReadBuffer(Socket socket) - { - var timeout = TimeSpan.FromMilliseconds(500); - var buffer = new byte[256]; - int bytesReceived; - - do - { - bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout); - } - while (bytesReceived > 0); - } - - public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout) - { - socket.ReceiveTimeout = timeout.AsTimeout(nameof(timeout)); - - try - { - return socket.Receive(buffer, offset, size, SocketFlags.None); - } - catch (SocketException ex) - { - if (ex.SocketErrorCode == SocketError.TimedOut) - { - throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, - "Socket read operation has timed out after {0:F0} milliseconds.", - timeout.TotalMilliseconds)); - } - - throw; - } } public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action processReceivedBytesAction) @@ -167,11 +59,6 @@ public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int } catch (SocketException ex) { - if (IsErrorResumable(ex.SocketErrorCode)) - { - continue; - } - #pragma warning disable IDE0010 // Add missing cases switch (ex.SocketErrorCode) { @@ -212,41 +99,6 @@ public static int ReadByte(Socket socket, TimeSpan timeout) return buffer[0]; } - /// - /// Sends a byte using the specified . - /// - /// The to write to. - /// The value to send. - /// The write failed. - public static void SendByte(Socket socket, byte value) - { - var buffer = new[] { value }; - Send(socket, buffer, 0, 1); - } - - /// - /// Receives data from a bound . - /// - /// The to read from. - /// The number of bytes to receive. - /// Specifies the amount of time after which the call will time out. - /// - /// The bytes received. - /// - /// - /// If no data is available for reading, the method will - /// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the - /// call will throw a . - /// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the - /// method will complete immediately and throw a . - /// - public static byte[] Read(Socket socket, int size, TimeSpan timeout) - { - var buffer = new byte[size]; - _ = Read(socket, buffer, 0, size, timeout); - return buffer; - } - /// /// Receives data from a bound into a receive buffer. /// @@ -264,10 +116,6 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout) /// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the /// call will throw a . /// - /// - /// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the - /// method will complete immediately and throw a . - /// /// public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout) { @@ -288,22 +136,12 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS totalBytesRead += bytesRead; } - catch (SocketException ex) + catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut) { - if (IsErrorResumable(ex.SocketErrorCode)) - { - ThreadAbstraction.Sleep(30); - continue; - } - - if (ex.SocketErrorCode == SocketError.TimedOut) - { - throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, - "Socket read operation has timed out after {0:F0} milliseconds.", - readTimeout.TotalMilliseconds)); - } - - throw; + throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, + "Socket read operation has timed out after {0:F0} milliseconds.", + readTimeout.TotalMilliseconds), + ex); } } while (totalBytesRead < totalBytesToRead); @@ -311,71 +149,34 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS return totalBytesRead; } -#if NET6_0_OR_GREATER == false - public static Task ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) - { - return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken); - } -#endif - - public static void Send(Socket socket, byte[] data) + public static async ValueTask ReadAsync(Socket socket, byte[] buffer, int offset, int size, CancellationToken cancellationToken) { - Send(socket, data, 0, data.Length); - } - - public static void Send(Socket socket, byte[] data, int offset, int size) - { - var totalBytesSent = 0; // how many bytes are already sent - var totalBytesToSend = size; + var totalBytesRead = 0; + var totalBytesToRead = size; do { try { - var bytesSent = socket.Send(data, offset + totalBytesSent, totalBytesToSend - totalBytesSent, SocketFlags.None); - if (bytesSent == 0) + var bytesRead = await socket.ReceiveAsync(new ArraySegment(buffer, offset + totalBytesRead, totalBytesToRead - totalBytesRead), SocketFlags.None, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) { - throw new SshConnectionException("An established connection was aborted by the server.", - DisconnectReason.ConnectionLost); + return 0; } - totalBytesSent += bytesSent; + totalBytesRead += bytesRead; } - catch (SocketException ex) + catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut) { - if (IsErrorResumable(ex.SocketErrorCode)) - { - // socket buffer is probably full, wait and try again - ThreadAbstraction.Sleep(30); - } - else - { - throw; // any serious error occurr - } + throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture, + "Socket read operation has timed out after {0:F0} milliseconds.", + socket.ReceiveTimeout), + ex); } } - while (totalBytesSent < totalBytesToSend); - } - - public static bool IsErrorResumable(SocketError socketError) - { -#pragma warning disable IDE0010 // Add missing cases - switch (socketError) - { - case SocketError.WouldBlock: - case SocketError.IOPending: - case SocketError.NoBufferSpaceAvailable: - return true; - default: - return false; - } -#pragma warning restore IDE0010 // Add missing cases - } + while (totalBytesRead < totalBytesToRead); - private static void ConnectCompleted(object sender, SocketAsyncEventArgs e) - { - var eventWaitHandle = (ManualResetEvent) e.UserToken; - _ = eventWaitHandle?.Set(); + return totalBytesRead; } } } diff --git a/src/Renci.SshNet/Abstractions/SocketExtensions.cs b/src/Renci.SshNet/Abstractions/SocketExtensions.cs index 2c34c899c..67949eaa4 100644 --- a/src/Renci.SshNet/Abstractions/SocketExtensions.cs +++ b/src/Renci.SshNet/Abstractions/SocketExtensions.cs @@ -1,134 +1,126 @@ -#if !NET6_0_OR_GREATER +#if !NET +#if NETFRAMEWORK || NETSTANDARD2_0 using System; +#endif using System.Net; using System.Net.Sockets; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace Renci.SshNet.Abstractions { - // Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/ internal static class SocketExtensions { - private sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, INotifyCompletion + public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken) { - private static readonly Action SENTINEL = () => { }; + cancellationToken.ThrowIfCancellationRequested(); - private bool _isCancelled; - private Action _continuationAction; + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - public AwaitableSocketAsyncEventArgs() + using var args = new SocketAsyncEventArgs { - Completed += (sender, e) => SetCompleted(); - } + RemoteEndPoint = remoteEndpoint + }; + args.Completed += (_, _) => tcs.TrySetResult(null); - public AwaitableSocketAsyncEventArgs ExecuteAsync(Func func) + if (socket.ConnectAsync(args)) { - if (!func(this)) +#if NETSTANDARD2_1 + await using (cancellationToken.Register(() => +#else + using (cancellationToken.Register(() => +#endif { - SetCompleted(); - } - - return this; - } - - public void SetCompleted() - { - IsCompleted = true; - - var continuation = _continuationAction ?? Interlocked.CompareExchange(ref _continuationAction, SENTINEL, comparand: null); - if (continuation is not null) + if (tcs.TrySetCanceled(cancellationToken)) + { + socket.Dispose(); + } + }, + useSynchronizationContext: false) +#if NETSTANDARD2_1 + .ConfigureAwait(false) +#endif + ) { - continuation(); + _ = await tcs.Task.ConfigureAwait(false); } } - public void SetCancelled() + if (args.SocketError != SocketError.Success) { - _isCancelled = true; - SetCompleted(); + throw new SocketException((int) args.SocketError); } + } -#pragma warning disable S1144 // Unused private types or members should be removed - public AwaitableSocketAsyncEventArgs GetAwaiter() -#pragma warning restore S1144 // Unused private types or members should be removed - { - return this; - } +#if NETFRAMEWORK || NETSTANDARD2_0 + public static async ValueTask ReceiveAsync(this Socket socket, ArraySegment buffer, SocketFlags socketFlags, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); - public bool IsCompleted { get; private set; } + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - void INotifyCompletion.OnCompleted(Action continuation) - { - if (_continuationAction == SENTINEL || Interlocked.CompareExchange(ref _continuationAction, continuation, comparand: null) == SENTINEL) - { - // We have already completed; run continuation asynchronously - _ = Task.Run(continuation); - } - } + using var args = new SocketAsyncEventArgs(); + args.SocketFlags = socketFlags; + args.Completed += (_, _) => tcs.TrySetResult(null); + args.SetBuffer(buffer.Array, buffer.Offset, buffer.Count); -#pragma warning disable S1144 // Unused private types or members should be removed - public void GetResult() -#pragma warning restore S1144 // Unused private types or members should be removed + if (socket.ReceiveAsync(args)) { - if (_isCancelled) + using (cancellationToken.Register(() => { - throw new TaskCanceledException(); - } - - if (!IsCompleted) - { - // We don't support sync/async - throw new InvalidOperationException("The asynchronous operation has not yet completed."); - } - - if (SocketError != SocketError.Success) + if (tcs.TrySetCanceled(cancellationToken)) + { + socket.Dispose(); + } + }, + useSynchronizationContext: false)) { - throw new SocketException((int)SocketError); + _ = await tcs.Task.ConfigureAwait(false); } } - } - public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - using (var args = new AwaitableSocketAsyncEventArgs()) + if (args.SocketError != SocketError.Success) { - args.RemoteEndPoint = remoteEndpoint; - -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs)o).SetCancelled(), args, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - await args.ExecuteAsync(socket.ConnectAsync); - } + throw new SocketException((int) args.SocketError); } + + return args.BytesTransferred; } - public static async Task ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken) + public static async ValueTask SendAsync(this Socket socket, byte[] buffer, SocketFlags socketFlags, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (var args = new AwaitableSocketAsyncEventArgs()) - { - args.SetBuffer(buffer, offset, length); + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER + using var args = new SocketAsyncEventArgs(); + args.SocketFlags = socketFlags; + args.Completed += (_, _) => tcs.TrySetResult(null); + args.SetBuffer(buffer, 0, buffer.Length); + + if (socket.SendAsync(args)) + { + using (cancellationToken.Register(() => + { + if (tcs.TrySetCanceled(cancellationToken)) + { + socket.Dispose(); + } + }, + useSynchronizationContext: false)) { - await args.ExecuteAsync(socket.ReceiveAsync); + _ = await tcs.Task.ConfigureAwait(false); } + } - return args.BytesTransferred; + if (args.SocketError != SocketError.Success) + { + throw new SocketException((int) args.SocketError); } + + return args.BytesTransferred; } +#endif // NETFRAMEWORK || NETSTANDARD2_0 } } #endif diff --git a/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs b/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs index 6c521bce2..3dc7f1e4c 100644 --- a/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs +++ b/src/Renci.SshNet/Channels/ChannelDirectTcpip.cs @@ -201,7 +201,7 @@ protected override void OnData(byte[] data) { if (_socket.IsConnected()) { - SocketAbstraction.Send(_socket, data, 0, data.Length); + _ = _socket.Send(data); } } } diff --git a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs index a8382015a..d881a7fa7 100644 --- a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs +++ b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs @@ -3,6 +3,7 @@ using System.Net.Sockets; using Renci.SshNet.Abstractions; using Renci.SshNet.Common; +using Renci.SshNet.Connection; using Renci.SshNet.Messages.Connection; namespace Renci.SshNet.Channels @@ -13,6 +14,7 @@ namespace Renci.SshNet.Channels internal sealed class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTcpip { private readonly object _socketShutdownAndCloseLock = new object(); + private readonly ISocketFactory _socketFactory; private Socket _socket; private IForwardedPort _forwardedPort; @@ -20,6 +22,7 @@ internal sealed class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTc /// Initializes a new instance of the class. /// /// The session. + /// The socket factory. /// The local channel number. /// Size of the window. /// Size of the packet. @@ -27,6 +30,7 @@ internal sealed class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTc /// The window size of the remote party. /// The maximum size of a data packet that we can send to the remote party. internal ChannelForwardedTcpip(ISession session, + ISocketFactory socketFactory, uint localChannelNumber, uint localWindowSize, uint localPacketSize, @@ -41,6 +45,7 @@ internal ChannelForwardedTcpip(ISession session, remoteWindowSize, remotePacketSize) { + _socketFactory = socketFactory; } /// @@ -72,7 +77,9 @@ public void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort) // Try to connect to the socket try { - _socket = SocketAbstraction.Connect(remoteEndpoint, ConnectionInfo.Timeout); + _socket = _socketFactory.Create(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + SocketAbstraction.Connect(_socket, remoteEndpoint, ConnectionInfo.Timeout); // Send channel open confirmation message SendMessage(new ChannelOpenConfirmationMessage(RemoteChannelNumber, LocalWindowSize, LocalPacketSize, LocalChannelNumber)); @@ -201,7 +208,7 @@ protected override void OnData(byte[] data) var socket = _socket; if (socket.IsConnected()) { - SocketAbstraction.Send(socket, data, 0, data.Length); + _ = socket.Send(data); } } } diff --git a/src/Renci.SshNet/Common/Extensions.cs b/src/Renci.SshNet/Common/Extensions.cs index 80fa8323d..2eb89e3e4 100644 --- a/src/Renci.SshNet/Common/Extensions.cs +++ b/src/Renci.SshNet/Common/Extensions.cs @@ -5,7 +5,7 @@ using System.Net; using System.Net.Sockets; using System.Text; -using Renci.SshNet.Abstractions; + using Renci.SshNet.Messages; namespace Renci.SshNet.Common @@ -336,22 +336,17 @@ public static byte[] Concat(this byte[] first, byte[] second) internal static bool CanRead(this Socket socket) { - return SocketAbstraction.CanRead(socket); + return socket.Connected && socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0; } internal static bool CanWrite(this Socket socket) { - return SocketAbstraction.CanWrite(socket); + return socket is not null && socket.Connected && socket.Poll(-1, SelectMode.SelectWrite); } internal static bool IsConnected(this Socket socket) { - if (socket is null) - { - return false; - } - - return socket.Connected; + return socket is not null && socket.Connected; } } } diff --git a/src/Renci.SshNet/Connection/ConnectorBase.cs b/src/Renci.SshNet/Connection/ConnectorBase.cs index c1b30fb8f..43e4e4bbc 100644 --- a/src/Renci.SshNet/Connection/ConnectorBase.cs +++ b/src/Renci.SshNet/Connection/ConnectorBase.cs @@ -86,7 +86,7 @@ protected async Task SocketConnectAsync(string host, int port, Cancellat var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp); try { - await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false); + await socket.ConnectAsync(ep, cancellationToken).ConfigureAwait(false); const int socketBufferSize = 2 * Session.MaximumSshPacketSize; socket.SendBufferSize = socketBufferSize; diff --git a/src/Renci.SshNet/Connection/HttpConnector.cs b/src/Renci.SshNet/Connection/HttpConnector.cs index afbaf0f01..72f502b00 100644 --- a/src/Renci.SshNet/Connection/HttpConnector.cs +++ b/src/Renci.SshNet/Connection/HttpConnector.cs @@ -3,6 +3,7 @@ using System.Globalization; using System.Net; using System.Net.Sockets; +using System.Text; using System.Text.RegularExpressions; using Renci.SshNet.Abstractions; @@ -41,7 +42,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke var httpResponseRe = new Regex(@"HTTP/(?\d[.]\d) (?\d{3}) (?.+)$"); var httpHeaderRe = new Regex(@"(?[^\[\]()<>@,;:\""/?={} \t]+):(?.+)?"); - SocketAbstraction.Send(socket, SshData.Ascii.GetBytes(string.Format(CultureInfo.InvariantCulture, + _ = socket.Send(Encoding.ASCII.GetBytes(string.Format(CultureInfo.InvariantCulture, "CONNECT {0}:{1} HTTP/1.0\r\n", connectionInfo.Host, connectionInfo.Port))); @@ -51,11 +52,11 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke { var authorization = string.Format(CultureInfo.InvariantCulture, "Proxy-Authorization: Basic {0}\r\n", - Convert.ToBase64String(SshData.Ascii.GetBytes($"{connectionInfo.ProxyUsername}:{connectionInfo.ProxyPassword}"))); - SocketAbstraction.Send(socket, SshData.Ascii.GetBytes(authorization)); + Convert.ToBase64String(Encoding.ASCII.GetBytes($"{connectionInfo.ProxyUsername}:{connectionInfo.ProxyPassword}"))); + _ = socket.Send(Encoding.ASCII.GetBytes(authorization)); } - SocketAbstraction.Send(socket, SshData.Ascii.GetBytes("\r\n")); + _ = socket.Send(Encoding.ASCII.GetBytes("\r\n")); HttpStatusCode? statusCode = null; var contentLength = 0; diff --git a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs index b14da93c0..607249e2b 100644 --- a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs +++ b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs @@ -38,7 +38,7 @@ public SshIdentification Start(string clientVersion, Socket socket, TimeSpan tim { // Immediately send the identification string since the spec states both sides MUST send an identification string // when the connection has been established - SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A")); + _ = socket.Send(Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A")); var bytesReceived = new List(); @@ -81,11 +81,7 @@ public async Task StartAsync(string clientVersion, Socket soc { // Immediately send the identification string since the spec states both sides MUST send an identification string // when the connection has been established -#if NET6_0_OR_GREATER - await SocketAbstraction.SendAsync(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), cancellationToken).ConfigureAwait(false); -#else - SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A")); -#endif // NET6_0_OR_GREATER + _ = await socket.SendAsync(Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), SocketFlags.None, cancellationToken).ConfigureAwait(false); var bytesReceived = new List(); @@ -191,7 +187,7 @@ private static async Task SocketReadLineAsync(Socket socket, List // to be processed by subsequent invocations. while (true) { - var bytesRead = await SocketAbstraction.ReadAsync(socket, data, cancellationToken).ConfigureAwait(false); + var bytesRead = await SocketAbstraction.ReadAsync(socket, data, 0, data.Length, cancellationToken).ConfigureAwait(false); if (bytesRead == 0) { throw new SshConnectionException("The connection was closed by the remote host."); diff --git a/src/Renci.SshNet/Connection/Socks4Connector.cs b/src/Renci.SshNet/Connection/Socks4Connector.cs index e3e9800f0..f8b0926b2 100644 --- a/src/Renci.SshNet/Connection/Socks4Connector.cs +++ b/src/Renci.SshNet/Connection/Socks4Connector.cs @@ -3,7 +3,6 @@ using System.Net.Sockets; using System.Text; -using Renci.SshNet.Abstractions; using Renci.SshNet.Common; namespace Renci.SshNet.Connection @@ -29,7 +28,7 @@ public Socks4Connector(ISocketFactory socketFactory) protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socket socket) { var connectionRequest = CreateSocks4ConnectionRequest(connectionInfo.Host, (ushort)connectionInfo.Port, connectionInfo.ProxyUsername); - SocketAbstraction.Send(socket, connectionRequest); + _ = socket.Send(connectionRequest); // Read reply version if (SocketReadByte(socket, connectionInfo.Timeout) != 0x00) diff --git a/src/Renci.SshNet/Connection/Socks5Connector.cs b/src/Renci.SshNet/Connection/Socks5Connector.cs index ecd286e00..8c42e591e 100644 --- a/src/Renci.SshNet/Connection/Socks5Connector.cs +++ b/src/Renci.SshNet/Connection/Socks5Connector.cs @@ -41,7 +41,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke // Username/Password authentication 0x02 }; - SocketAbstraction.Send(socket, greeting); + _ = socket.Send(greeting); var socksVersion = SocketReadByte(socket); if (socksVersion != 0x05) @@ -60,10 +60,11 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke var authenticationRequest = CreateSocks5UserNameAndPasswordAuthenticationRequest(connectionInfo.ProxyUsername, connectionInfo.ProxyPassword); // Send authentication request - SocketAbstraction.Send(socket, authenticationRequest); + _ = socket.Send(authenticationRequest); // Read authentication result - var authenticationResult = SocketAbstraction.Read(socket, 2, connectionInfo.Timeout); + var authenticationResult = new byte[2]; + _ = SocketAbstraction.Read(socket, authenticationResult, 0, authenticationResult.Length, connectionInfo.Timeout); if (authenticationResult[0] != 0x01) { @@ -83,7 +84,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke } var connectionRequest = CreateSocks5ConnectionRequest(connectionInfo.Host, (ushort) connectionInfo.Port); - SocketAbstraction.Send(socket, connectionRequest); + _ = socket.Send(connectionRequest); // Read Server SOCKS5 version if (SocketReadByte(socket) != 5) diff --git a/src/Renci.SshNet/ForwardedPortDynamic.cs b/src/Renci.SshNet/ForwardedPortDynamic.cs index 2a2c45f2c..b6705eb55 100644 --- a/src/Renci.SshNet/ForwardedPortDynamic.cs +++ b/src/Renci.SshNet/ForwardedPortDynamic.cs @@ -503,18 +503,18 @@ private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan t channel.Open(host, port, this, socket); - SocketAbstraction.SendByte(socket, 0x00); + _ = socket.Send([0x00]); if (channel.IsOpen) { - SocketAbstraction.SendByte(socket, 0x5a); - SocketAbstraction.Send(socket, portBuffer, 0, portBuffer.Length); - SocketAbstraction.Send(socket, ipBuffer, 0, ipBuffer.Length); + _ = socket.Send([0x5a]); + _ = socket.Send(portBuffer); + _ = socket.Send(ipBuffer); return true; } // signal that request was rejected or failed - SocketAbstraction.SendByte(socket, 0x5b); + _ = socket.Send([0x5b]); return false; } @@ -538,12 +538,12 @@ private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan t { // no user authentication is one of the authentication methods supported // by the SOCKS client - SocketAbstraction.Send(socket, new byte[] { 0x05, 0x00 }, 0, 2); + _ = socket.Send([0x05, 0x00]); } else { // the SOCKS client requires authentication, which we currently do not support - SocketAbstraction.Send(socket, new byte[] { 0x05, 0xFF }, 0, 2); + _ = socket.Send([0x05, 0xFF]); // we continue business as usual but expect the client to close the connection // so one of the subsequent reads should return -1 signaling that the client @@ -610,7 +610,7 @@ private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan t var socksReply = CreateSocks5Reply(channel.IsOpen); - SocketAbstraction.Send(socket, socksReply, 0, socksReply.Length); + _ = socket.Send(socksReply); return true; } diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 0c067ec2d..187e68686 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -1144,7 +1144,7 @@ private void SendPacket(byte[] packet, int offset, int length) throw new SshConnectionException("Client not connected."); } - SocketAbstraction.Send(_socket, packet, offset, length); + _ = _socket.Send(packet, offset, length, SocketFlags.None); } finally { @@ -2216,6 +2216,7 @@ IChannelForwardedTcpip ISession.CreateChannelForwardedTcpip(uint remoteChannelNu uint remoteChannelDataPacketSize) { return new ChannelForwardedTcpip(this, + _socketFactory, NextChannelNumber, InitialLocalWindowSize, LocalChannelDataPacketSize, diff --git a/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs b/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs index e50858c33..69d9bf1af 100644 --- a/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs +++ b/test/Renci.SshNet.IntegrationTests/Common/Socks5Handler.cs @@ -3,18 +3,21 @@ using Renci.SshNet.Abstractions; using Renci.SshNet.Common; +using Renci.SshNet.Connection; using Renci.SshNet.Messages.Transport; namespace Renci.SshNet.IntegrationTests.Common { class Socks5Handler { + private readonly ISocketFactory _socketFactory; private readonly IPEndPoint _proxyEndPoint; private readonly string _userName; private readonly string _password; public Socks5Handler(IPEndPoint proxyEndPoint, string userName, string password) { + _socketFactory = new SocketFactory(); _proxyEndPoint = proxyEndPoint; _userName = userName; _password = password; @@ -52,17 +55,19 @@ public Socket Connect(string host, int port) private Socket Connect(byte[] addressBytes, int port) { - var socket = SocketAbstraction.Connect(_proxyEndPoint, TimeSpan.FromSeconds(5)); + var socket = _socketFactory.Create(_proxyEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + SocketAbstraction.Connect(socket, _proxyEndPoint, TimeSpan.FromSeconds(5)); // Send socks version number - SocketWriteByte(socket, 0x05); + _ = socket.Send([0x05]); // Send number of supported authentication methods - SocketWriteByte(socket, 0x02); + _ = socket.Send([0x02]); // Send supported authentication methods - SocketWriteByte(socket, 0x00); // No authentication - SocketWriteByte(socket, 0x02); // Username/Password + _ = socket.Send([0x00]); // No authentication + _ = socket.Send([0x02]); // Username/Password var socksVersion = SocketReadByte(socket); if (socksVersion != 0x05) @@ -78,7 +83,7 @@ private Socket Connect(byte[] addressBytes, int port) case 0x02: // Send version - SocketWriteByte(socket, 0x01); + _ = socket.Send([0x01]); var username = Encoding.ASCII.GetBytes(_userName); if (username.Length > byte.MaxValue) @@ -87,10 +92,10 @@ private Socket Connect(byte[] addressBytes, int port) } // Send username length - SocketWriteByte(socket, (byte) username.Length); + _ = socket.Send([(byte) username.Length]); // Send username - SocketAbstraction.Send(socket, username); + _ = socket.Send(username); var password = Encoding.ASCII.GetBytes(_password); @@ -99,11 +104,11 @@ private Socket Connect(byte[] addressBytes, int port) throw new ProxyException("Proxy password is too long."); } - // Send username length - SocketWriteByte(socket, (byte) password.Length); + // Send password length + _ = socket.Send([(byte) password.Length]); - // Send username - SocketAbstraction.Send(socket, password); + // Send password + _ = socket.Send(password); var serverVersion = SocketReadByte(socket); @@ -126,20 +131,19 @@ private Socket Connect(byte[] addressBytes, int port) } // Send socks version number - SocketWriteByte(socket, 0x05); + _ = socket.Send([0x05]); // Send command code - SocketWriteByte(socket, 0x01); // establish a TCP/IP stream connection + _ = socket.Send([0x01]); // establish a TCP/IP stream connection // Send reserved, must be 0x00 - SocketWriteByte(socket, 0x00); + _ = socket.Send([0x00]); // Send address type and address - SocketAbstraction.Send(socket, addressBytes); + _ = socket.Send(addressBytes); // Send port - SocketWriteByte(socket, (byte)(port / 0xFF)); - SocketWriteByte(socket, (byte)(port % 0xFF)); + _ = socket.Send([(byte) (port / 0xFF), (byte) (port % 0xFF)]); // Read Server SOCKS5 version if (SocketReadByte(socket) != 5) @@ -226,11 +230,6 @@ private static byte[] GetAddressBytes(IPEndPoint endPoint) throw new ProxyException(string.Format("SOCKS5: IP address '{0}' is not supported.", endPoint.Address)); } - private static void SocketWriteByte(Socket socket, byte data) - { - SocketAbstraction.Send(socket, new[] { data }); - } - private static byte SocketReadByte(Socket socket) { var buffer = new byte[1]; diff --git a/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs b/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs index 393eae093..df5fc4a95 100644 --- a/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs +++ b/test/Renci.SshNet.Tests/Classes/Channels/ChannelForwardedTcpipTest_Dispose_SessionIsConnectedAndChannelIsOpen.cs @@ -9,6 +9,7 @@ using Moq; using Renci.SshNet.Channels; +using Renci.SshNet.Connection; using Renci.SshNet.Messages.Connection; using Renci.SshNet.Tests.Common; @@ -140,6 +141,7 @@ private void Arrange() _remoteListener.Start(); _channel = new ChannelForwardedTcpip(_sessionMock.Object, + new SocketFactory(), _localChannelNumber, _localWindowSize, _localPacketSize, diff --git a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs index 6808bfb50..22af7f447 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingHttpContent.cs @@ -119,7 +119,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs index 38f65634c..d16af22c0 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_TimeoutReadingStatusLine.cs @@ -95,7 +95,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs b/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs index 3710e2064..a3f1b0e9b 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/ProtocolVersionExchangeTest_TimeoutReadingIdentificationString.cs @@ -93,8 +93,8 @@ protected void Act() [TestMethod] public void StartShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNotNull(_actualException); - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format("Socket read operation has timed out after {0} milliseconds.", _timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs index d87969ced..3604aac72 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingDestinationAddress.cs @@ -97,7 +97,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs index 8f6ee9019..19a4b777b 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyCode.cs @@ -93,7 +93,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); } diff --git a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs index 4ca6e0c58..0e5ace5e5 100644 --- a/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs +++ b/test/Renci.SshNet.Tests/Classes/Connection/Socks4ConnectorTest_Connect_TimeoutReadingReplyVersion.cs @@ -82,7 +82,8 @@ protected override void Act() [TestMethod] public void ConnectShouldHaveThrownSshOperationTimeoutException() { - Assert.IsNull(_actualException.InnerException); + Assert.IsInstanceOfType(_actualException); + Assert.IsInstanceOfType(_actualException.InnerException); Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, "Socket read operation has timed out after {0:F0} milliseconds.", _connectionInfo.Timeout.TotalMilliseconds), _actualException.Message); }