Skip to content

Commit

Permalink
Add async support to SftpClient and SftpFileStream (#819)
Browse files Browse the repository at this point in the history
* Add FEATURE_TAP and net472 target
* Add TAP async support to SftpClient and SftpFileStream
* Add async support to DnsAbstraction and SocketAbstraction
* Add async support to *Connector and refactor the hierarchy
* Add ConnectAsync to BaseClient
  • Loading branch information
IgorMilavec authored Dec 14, 2021
1 parent 30d79c7 commit 7bdfc9e
Show file tree
Hide file tree
Showing 21 changed files with 1,611 additions and 70 deletions.
22 changes: 22 additions & 0 deletions src/Renci.SshNet/Abstractions/DnsAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
using System.Net;
using System.Net.Sockets;

#if FEATURE_TAP
using System.Threading.Tasks;
#endif

#if FEATURE_DNS_SYNC
#elif FEATURE_DNS_APM
using Renci.SshNet.Common;
Expand Down Expand Up @@ -87,5 +91,23 @@ public static IPAddress[] GetHostAddresses(string hostNameOrAddress)
#endif // FEATURE_DEVICEINFORMATION_APM
#endif
}

#if FEATURE_TAP
/// <summary>
/// Returns the Internet Protocol (IP) addresses for the specified host.
/// </summary>
/// <param name="hostNameOrAddress">The host name or IP address to resolve</param>
/// <returns>
/// A task with result of an array of type <see cref="IPAddress"/> that holds the IP addresses for the host that
/// is specified by the <paramref name="hostNameOrAddress"/> parameter.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="hostNameOrAddress"/> is <c>null</c>.</exception>
/// <exception cref="SocketException">An error is encountered when resolving <paramref name="hostNameOrAddress"/>.</exception>
public static Task<IPAddress[]> GetHostAddressesAsync(string hostNameOrAddress)
{
return Dns.GetHostAddressesAsync(hostNameOrAddress);
}
#endif

}
}
17 changes: 17 additions & 0 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using System.Net;
using System.Net.Sockets;
using System.Threading;
#if FEATURE_TAP
using System.Threading.Tasks;
#endif
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

Expand Down Expand Up @@ -59,6 +62,13 @@ public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan co
ConnectCore(socket, remoteEndpoint, connectTimeout, false);
}

#if FEATURE_TAP
public static Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
{
return socket.ConnectAsync(remoteEndpoint, cancellationToken);
}
#endif

private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
{
#if FEATURE_SOCKET_EAP
Expand Down Expand Up @@ -317,6 +327,13 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout)
return buffer;
}

#if FEATURE_TAP
public static Task<int> ReadAsync(Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, offset, length, cancellationToken);
}
#endif

/// <summary>
/// Receives data from a bound <see cref="Socket"/> into a receive buffer.
/// </summary>
Expand Down
119 changes: 119 additions & 0 deletions src/Renci.SshNet/Abstractions/SocketExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#if FEATURE_TAP
using System;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace Renci.SshNet.Abstractions
{
// Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/

internal static class SocketExtensions
{
sealed class SocketAsyncEventArgsAwaitable : SocketAsyncEventArgs, INotifyCompletion
{
private readonly static Action SENTINEL = () => { };

private bool isCancelled;
private Action continuationAction;

public SocketAsyncEventArgsAwaitable()
{
Completed += delegate { SetCompleted(); };
}

public SocketAsyncEventArgsAwaitable ExecuteAsync(Func<SocketAsyncEventArgs, bool> func)
{
if (!func(this))
{
SetCompleted();
}
return this;
}

public void SetCompleted()
{
IsCompleted = true;
var continuation = continuationAction ?? Interlocked.CompareExchange(ref continuationAction, SENTINEL, null);
if (continuation != null)
{
continuation();
}
}

public void SetCancelled()
{
isCancelled = true;
SetCompleted();
}

public SocketAsyncEventArgsAwaitable GetAwaiter() { return this; }

public bool IsCompleted { get; private set; }

void INotifyCompletion.OnCompleted(Action continuation)
{
if (continuationAction == SENTINEL || Interlocked.CompareExchange(ref continuationAction, continuation, null) == SENTINEL)
{
// We have already completed; run continuation asynchronously
Task.Run(continuation);
}
}

public void GetResult()
{
if (isCancelled)
{
throw new TaskCanceledException();
}
else if (IsCompleted)
{
if (SocketError != SocketError.Success)
{
throw new SocketException((int)SocketError);
}
}
else
{
// We don't support sync/async
throw new InvalidOperationException("The asynchronous operation has not yet completed.");
}
}
}

public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using (var args = new SocketAsyncEventArgsAwaitable())
{
args.RemoteEndPoint = remoteEndpoint;

using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
{
await args.ExecuteAsync(socket.ConnectAsync);
}
}
}

public static async Task<int> ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using (var args = new SocketAsyncEventArgsAwaitable())
{
args.SetBuffer(buffer, offset, length);

using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
{
await args.ExecuteAsync(socket.ReceiveAsync);
}

return args.BytesTransferred;
}
}
}
}
#endif
80 changes: 80 additions & 0 deletions src/Renci.SshNet/BaseClient.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using System.Net.Sockets;
using System.Threading;
#if FEATURE_TAP
using System.Threading.Tasks;
#endif
using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
Expand Down Expand Up @@ -239,6 +242,63 @@ public void Connect()
StartKeepAliveTimer();
}

#if FEATURE_TAP
/// <summary>
/// Asynchronously connects client to the server.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous connect operation.
/// </returns>
/// <exception cref="InvalidOperationException">The client is already connected.</exception>
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
/// <exception cref="SocketException">Socket connection to the SSH server or proxy server could not be established, or an error occurred while resolving the hostname.</exception>
/// <exception cref="SshConnectionException">SSH session could not be established.</exception>
/// <exception cref="SshAuthenticationException">Authentication of SSH session failed.</exception>
/// <exception cref="ProxyException">Failed to establish proxy connection.</exception>
public async Task ConnectAsync(CancellationToken cancellationToken)
{
CheckDisposed();
cancellationToken.ThrowIfCancellationRequested();

// TODO (see issue #1758):
// we're not stopping the keep-alive timer and disposing the session here
//
// we could do this but there would still be side effects as concrete
// implementations may still hang on to the original session
//
// therefore it would be better to actually invoke the Disconnect method
// (and then the Dispose on the session) but even that would have side effects
// eg. it would remove all forwarded ports from SshClient
//
// I think we should modify our concrete clients to better deal with a
// disconnect. In case of SshClient this would mean not removing the
// forwarded ports on disconnect (but only on dispose ?) and link a
// forwarded port with a client instead of with a session
//
// To be discussed with Oleg (or whoever is interested)
if (IsSessionConnected())
throw new InvalidOperationException("The client is already connected.");

OnConnecting();

Session = await CreateAndConnectSessionAsync(cancellationToken).ConfigureAwait(false);
try
{
// Even though the method we invoke makes you believe otherwise, at this point only
// the SSH session itself is connected.
OnConnected();
}
catch
{
// Only dispose the session as Disconnect() would have side-effects (such as remove forwarded
// ports in SshClient).
DisposeSession();
throw;
}
StartKeepAliveTimer();
}
#endif

/// <summary>
/// Disconnects client from the server.
/// </summary>
Expand Down Expand Up @@ -473,6 +533,26 @@ private ISession CreateAndConnectSession()
}
}

#if FEATURE_TAP
private async Task<ISession> CreateAndConnectSessionAsync(CancellationToken cancellationToken)
{
var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
session.HostKeyReceived += Session_HostKeyReceived;
session.ErrorOccured += Session_ErrorOccured;

try
{
await session.ConnectAsync(cancellationToken).ConfigureAwait(false);
return session;
}
catch
{
DisposeSession(session);
throw;
}
}
#endif

private void DisposeSession(ISession session)
{
session.ErrorOccured -= Session_ErrorOccured;
Expand Down
45 changes: 45 additions & 0 deletions src/Renci.SshNet/Connection/ConnectorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;

#if FEATURE_TAP
using System.Threading.Tasks;
#endif

namespace Renci.SshNet.Connection
{
Expand All @@ -21,6 +26,10 @@ protected ConnectorBase(ISocketFactory socketFactory)

public abstract Socket Connect(IConnectionInfo connectionInfo);

#if FEATURE_TAP
public abstract Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken);
#endif

/// <summary>
/// Establishes a socket connection to the specified host and port.
/// </summary>
Expand Down Expand Up @@ -54,6 +63,42 @@ protected Socket SocketConnect(string host, int port, TimeSpan timeout)
}
}

#if FEATURE_TAP
/// <summary>
/// Establishes a socket connection to the specified host and port.
/// </summary>
/// <param name="host">The host name of the server to connect to.</param>
/// <param name="port">The port to connect to.</param>
/// <param name="cancellationToken">The cancellation token to observe.</param>
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
protected async Task<Socket> SocketConnectAsync(string host, int port, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

var ipAddress = (await DnsAbstraction.GetHostAddressesAsync(host).ConfigureAwait(false))[0];
var ep = new IPEndPoint(ipAddress, port);

DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}:{1}'.", host, port));

var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false);

const int socketBufferSize = 2 * Session.MaximumSshPacketSize;
socket.SendBufferSize = socketBufferSize;
socket.ReceiveBufferSize = socketBufferSize;
return socket;
}
catch (Exception)
{
socket.Dispose();
throw;
}
}
#endif

protected static byte SocketReadByte(Socket socket)
{
var buffer = new byte[1];
Expand Down
10 changes: 9 additions & 1 deletion src/Renci.SshNet/Connection/DirectConnector.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using System.Net.Sockets;
using System.Threading;

namespace Renci.SshNet.Connection
{
internal class DirectConnector : ConnectorBase
internal sealed class DirectConnector : ConnectorBase
{
public DirectConnector(ISocketFactory socketFactory) : base(socketFactory)
{
Expand All @@ -12,5 +13,12 @@ public override Socket Connect(IConnectionInfo connectionInfo)
{
return SocketConnect(connectionInfo.Host, connectionInfo.Port, connectionInfo.Timeout);
}

#if FEATURE_TAP
public override System.Threading.Tasks.Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken)
{
return SocketConnectAsync(connectionInfo.Host, connectionInfo.Port, cancellationToken);
}
#endif
}
}
Loading

0 comments on commit 7bdfc9e

Please sign in to comment.