diff --git a/.gitignore b/.gitignore index 7d6bdbf60..b3fc569f2 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ project.lock.json # Build outputs build/target/ + +# Rider Directory +.idea/ \ No newline at end of file diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index 287caea66..6d1ac7554 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -54,12 +54,21 @@ public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout) return socket; } +#if FEATURE_UNIX_SOCKETS + public static Socket Connect(UnixDomainSocketEndPoint remoteEndpoint, TimeSpan connectTimeout) + { + var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Unspecified); + ConnectCore(socket, remoteEndpoint, connectTimeout, true); + return socket; + } +#endif + public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout) { ConnectCore(socket, remoteEndpoint, connectTimeout, false); } - private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket) + private static void ConnectCore(Socket socket, EndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket) { #if FEATURE_SOCKET_EAP var connectCompleted = new ManualResetEvent(false); diff --git a/src/Renci.SshNet/Channels/ChannelDirectStreamLocal.cs b/src/Renci.SshNet/Channels/ChannelDirectStreamLocal.cs new file mode 100644 index 000000000..62db5dea3 --- /dev/null +++ b/src/Renci.SshNet/Channels/ChannelDirectStreamLocal.cs @@ -0,0 +1,309 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Common; +using Renci.SshNet.Messages.Connection; + +namespace Renci.SshNet.Channels +{ + /// + /// Implements "direct-streamlocal@openssh.com" SSH channel. + /// + internal class ChannelDirectStreamLocal : ClientChannel, IChannelDirectStreamLocal + { + private readonly object _socketLock = new object(); + + private EventWaitHandle _channelOpen = new AutoResetEvent(false); + private EventWaitHandle _channelData = new AutoResetEvent(false); + private IForwardedPort _forwardedPort; + private Socket _socket; + + /// + /// Initializes a new instance. + /// + /// The session. + /// The local channel number. + /// Size of the window. + /// Size of the packet. + public ChannelDirectStreamLocal(ISession session, uint localChannelNumber, uint localWindowSize, uint localPacketSize) + : base(session, localChannelNumber, localWindowSize, localPacketSize) + { + } + + /// + /// Gets the type of the channel. + /// + /// + /// The type of the channel. + /// + public override ChannelTypes ChannelType + { + get { return ChannelTypes.DirectStreamLocal; } + } + + public void Open(string remoteSocket, IForwardedPort forwardedPort, Socket socket) + { + if (IsOpen) + throw new SshException("Channel is already open."); + if (!IsConnected) + throw new SshException("Session is not connected."); + + lock (_socketLock) + { + _socket = socket; + } + _forwardedPort = forwardedPort; + _forwardedPort.Closing += ForwardedPort_Closing; + + var originatorAddress = ""; + var originatorPort = (uint)0; + + if (socket.RemoteEndPoint is IPEndPoint) + { + var ep = (IPEndPoint)socket.RemoteEndPoint; + originatorAddress = ep.Address.ToString(); + originatorPort = (uint) ep.Port; + } + + SendMessage(new ChannelOpenMessage(LocalChannelNumber, LocalWindowSize, LocalPacketSize, + new DirectStreamLocalChannelInfo(remoteSocket, originatorAddress, originatorPort))); + // Wait for channel to open + WaitOnHandle(_channelOpen); + } + + /// + /// Occurs as the forwarded port is being stopped. + /// + private void ForwardedPort_Closing(object sender, EventArgs eventArgs) + { + // signal to the client that we will not send anything anymore; this should also interrupt the + // blocking receive in Bind if the client sends FIN/ACK in time + ShutdownSocket(SocketShutdown.Send); + + // if the FIN/ACK is not sent in time by the remote client, then interrupt the blocking receive + // by closing the socket + CloseSocket(); + } + + /// + /// Binds channel to remote host. + /// + public void Bind() + { + // Cannot bind if channel is not open + if (!IsOpen) + return; + + var buffer = new byte[RemotePacketSize]; + + SocketAbstraction.ReadContinuous(_socket, buffer, 0, buffer.Length, SendData); + + // even though the client has disconnected, we still want to properly close the + // channel + // + // we'll do this in in Close() - invoked through Dispose(bool) - that way we have + // a single place from which we send an SSH_MSG_CHANNEL_EOF message and wait for + // the SSH_MSG_CHANNEL_CLOSE message + } + + /// + /// Closes the socket, hereby interrupting the blocking receive in . + /// + private void CloseSocket() + { + if (_socket == null) + return; + + lock (_socketLock) + { + if (_socket == null) + return; + + // closing a socket actually disposes the socket, so we can safely dereference + // the field to avoid entering the lock again later + _socket.Dispose(); + _socket = null; + } + } + + /// + /// Shuts down the socket. + /// + /// One of the values that specifies the operation that will no longer be allowed. + private void ShutdownSocket(SocketShutdown how) + { + if (_socket == null) + return; + + lock (_socketLock) + { + if (!_socket.IsConnected()) + return; + + try + { + _socket.Shutdown(how); + } + catch (SocketException ex) + { + // TODO: log as warning + DiagnosticAbstraction.Log("Failure shutting down socket: " + ex); + } + } + } + + /// + /// Closes the channel, waiting for the SSH_MSG_CHANNEL_CLOSE message to be received from the server. + /// + protected override void Close() + { + var forwardedPort = _forwardedPort; + if (forwardedPort != null) + { + forwardedPort.Closing -= ForwardedPort_Closing; + _forwardedPort = null; + } + + // signal to the client that we will not send anything anymore; this will also interrupt the + // blocking receive in Bind if the client sends FIN/ACK in time + // + // if the FIN/ACK is not sent in time, the socket will be closed after the channel is closed + ShutdownSocket(SocketShutdown.Send); + + // close the SSH channel + base.Close(); + + // close the socket + CloseSocket(); + } + + /// + /// Called when channel data is received. + /// + /// The data. + protected override void OnData(byte[] data) + { + base.OnData(data); + + if (_socket != null) + { + lock (_socketLock) + { + if (_socket.IsConnected()) + { + SocketAbstraction.Send(_socket, data, 0, data.Length); + } + } + } + } + + /// + /// Called when channel is opened by the server. + /// + /// The remote channel number. + /// Initial size of the window. + /// Maximum size of the packet. + protected override void OnOpenConfirmation(uint remoteChannelNumber, uint initialWindowSize, uint maximumPacketSize) + { + base.OnOpenConfirmation(remoteChannelNumber, initialWindowSize, maximumPacketSize); + + _channelOpen.Set(); + } + + protected override void OnOpenFailure(uint reasonCode, string description, string language) + { + base.OnOpenFailure(reasonCode, description, language); + + _channelOpen.Set(); + } + + /// + /// Called when channel has no more data to receive. + /// + protected override void OnEof() + { + base.OnEof(); + + // the channel will send no more data, and hence it does not make sense to receive + // any more data from the client to send to the remote party (and we surely won't + // send anything anymore) + // + // this will also interrupt the blocking receive in Bind() + ShutdownSocket(SocketShutdown.Send); + } + + /// + /// Called whenever an unhandled occurs in causing + /// the message loop to be interrupted, or when an exception occurred processing a channel message. + /// + protected override void OnErrorOccured(Exception exp) + { + base.OnErrorOccured(exp); + + // signal to the client that we will not send anything anymore; this will also interrupt the + // blocking receive in Bind if the client sends FIN/ACK in time + // + // if the FIN/ACK is not sent in time, the socket will be closed in Close(bool) + ShutdownSocket(SocketShutdown.Send); + } + + /// + /// Called when the server wants to terminate the connection immmediately. + /// + /// + /// The sender MUST NOT send or receive any data after this message, and + /// the recipient MUST NOT accept any data after receiving this message. + /// + protected override void OnDisconnected() + { + base.OnDisconnected(); + + // the channel will accept or send no more data, and hence it does not make sense + // to accept any more data from the client (and we surely won't send anything + // anymore) + // + // so lets signal to the client that we will not send or receive anything anymore + // this will also interrupt the blocking receive in Bind() + ShutdownSocket(SocketShutdown.Both); + } + + protected override void Dispose(bool disposing) + { + // make sure we've unsubscribed from all session events and closed the channel + // before we starting disposing + base.Dispose(disposing); + + if (disposing) + { + if (_socket != null) + { + lock (_socketLock) + { + var socket = _socket; + if (socket != null) + { + _socket = null; + socket.Dispose(); + } + } + } + + var channelOpen = _channelOpen; + if (channelOpen != null) + { + _channelOpen = null; + channelOpen.Dispose(); + } + + var channelData = _channelData; + if (channelData != null) + { + _channelData = null; + channelData.Dispose(); + } + } + } + } +} \ No newline at end of file diff --git a/src/Renci.SshNet/Channels/ChannelForwardedStreamLocal.cs b/src/Renci.SshNet/Channels/ChannelForwardedStreamLocal.cs new file mode 100644 index 000000000..a32d294ee --- /dev/null +++ b/src/Renci.SshNet/Channels/ChannelForwardedStreamLocal.cs @@ -0,0 +1,213 @@ +using System; +using System.Net; +using System.Net.Sockets; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Common; +using Renci.SshNet.Messages.Connection; + +namespace Renci.SshNet.Channels +{ + /// + /// Implements "forwarded-streamlocal@openssh.com" SSH channel. + /// + internal class ChannelForwardedStreamLocal : ServerChannel, IChannelForwardedStreamLocal + { + private readonly object _socketShutdownAndCloseLock = new object(); + private Socket _socket; + private IForwardedPort _forwardedPort; + + /// + /// Initializes a new instance. + /// + /// The session. + /// The local channel number. + /// Size of the window. + /// Size of the packet. + /// The remote channel number. + /// The window size of the remote party. + /// The maximum size of a data packet that we can send to the remote party. + internal ChannelForwardedStreamLocal(ISession session, + uint localChannelNumber, + uint localWindowSize, + uint localPacketSize, + uint remoteChannelNumber, + uint remoteWindowSize, + uint remotePacketSize) + : base(session, + localChannelNumber, + localWindowSize, + localPacketSize, + remoteChannelNumber, + remoteWindowSize, + remotePacketSize) + { + } + + /// + /// Gets the type of the channel. + /// + /// + /// The type of the channel. + /// + public override ChannelTypes ChannelType + { + get { return ChannelTypes.ForwardedStreamLocal; } + } + + /// + /// Binds the channel to the specified endpoint. + /// + /// The endpoint to connect to. + /// The forwarded port for which the channel is opened. + public void Bind(EndPoint remoteEndpoint, IForwardedPort forwardedPort) + { + if (!IsConnected) + { + throw new SshException("Session is not connected."); + } + + _forwardedPort = forwardedPort; + _forwardedPort.Closing += ForwardedPort_Closing; + + // Try to connect to the socket + try + { +#if FEATURE_UNIX_SOCKETS + if (remoteEndpoint is UnixDomainSocketEndPoint) + { + _socket = SocketAbstraction.Connect((UnixDomainSocketEndPoint)remoteEndpoint, ConnectionInfo.Timeout); + } + else + { +#endif + _socket = SocketAbstraction.Connect((IPEndPoint)remoteEndpoint, ConnectionInfo.Timeout); +#if FEATURE_UNIX_SOCKETS + } +#endif + + // send channel open confirmation message + SendMessage(new ChannelOpenConfirmationMessage(RemoteChannelNumber, LocalWindowSize, LocalPacketSize, LocalChannelNumber)); + } + catch (Exception exp) + { + // send channel open failure message + SendMessage(new ChannelOpenFailureMessage(RemoteChannelNumber, exp.ToString(), ChannelOpenFailureMessage.ConnectFailed, "en")); + + throw; + } + + var buffer = new byte[RemotePacketSize]; + + SocketAbstraction.ReadContinuous(_socket, buffer, 0, buffer.Length, SendData); + } + + protected override void OnErrorOccured(Exception exp) + { + base.OnErrorOccured(exp); + + // signal to the server that we will not send anything anymore; this will also interrupt the + // blocking receive in Bind if the server sends FIN/ACK in time + // + // if the FIN/ACK is not sent in time, the socket will be closed in Close(bool) + ShutdownSocket(SocketShutdown.Send); + } + + /// + /// Occurs as the forwarded port is being stopped. + /// + private void ForwardedPort_Closing(object sender, EventArgs eventArgs) + { + // signal to the server that we will not send anything anymore; this will also interrupt the + // blocking receive in Bind if the server sends FIN/ACK in time + // + // if the FIN/ACK is not sent in time, the socket will be closed in Close(bool) + ShutdownSocket(SocketShutdown.Send); + } + + /// + /// Shuts down the socket. + /// + /// One of the values that specifies the operation that will no longer be allowed. + private void ShutdownSocket(SocketShutdown how) + { + if (_socket == null) + return; + + lock (_socketShutdownAndCloseLock) + { + var socket = _socket; + if (!socket.IsConnected()) + return; + + try + { + socket.Shutdown(how); + } + catch (SocketException ex) + { + // TODO: log as warning + DiagnosticAbstraction.Log("Failure shutting down socket: " + ex); + } + } + } + + /// + /// Closes the socket, hereby interrupting the blocking receive in . + /// + private void CloseSocket() + { + if (_socket == null) + return; + + lock (_socketShutdownAndCloseLock) + { + var socket = _socket; + if (socket != null) + { + _socket = null; + socket.Dispose(); + } + } + } + + /// + /// Closes the channel waiting for the SSH_MSG_CHANNEL_CLOSE message to be received from the server. + /// + protected override void Close() + { + var forwardedPort = _forwardedPort; + if (forwardedPort != null) + { + forwardedPort.Closing -= ForwardedPort_Closing; + _forwardedPort = null; + } + + // signal to the server that we will not send anything anymore; this will also interrupt the + // blocking receive in Bind if the server sends FIN/ACK in time + // + // if the FIN/ACK is not sent in time, the socket will be closed after the channel is closed + ShutdownSocket(SocketShutdown.Send); + + // close the SSH channel, and mark the channel closed + base.Close(); + + // close the socket + CloseSocket(); + } + + /// + /// Called when channel data is received. + /// + /// The data. + protected override void OnData(byte[] data) + { + base.OnData(data); + + var socket = _socket; + if (socket.IsConnected()) + { + SocketAbstraction.Send(socket, data, 0, data.Length); + } + } + } +} \ No newline at end of file diff --git a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs index 7d731b9e5..f26503b28 100644 --- a/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs +++ b/src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs @@ -59,7 +59,7 @@ public override ChannelTypes ChannelType /// /// The endpoint to connect to. /// The forwarded port for which the channel is opened. - public void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort) + public void Bind(EndPoint remoteEndpoint, IForwardedPort forwardedPort) { if (!IsConnected) { @@ -72,7 +72,18 @@ public void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort) // Try to connect to the socket try { - _socket = SocketAbstraction.Connect(remoteEndpoint, ConnectionInfo.Timeout); +#if FEATURE_UNIX_SOCKETS + if (remoteEndpoint is UnixDomainSocketEndPoint) + { + _socket = SocketAbstraction.Connect((UnixDomainSocketEndPoint) remoteEndpoint, ConnectionInfo.Timeout); + } + else + { +#endif + _socket = SocketAbstraction.Connect((IPEndPoint)remoteEndpoint, ConnectionInfo.Timeout); +#if FEATURE_UNIX_SOCKETS + } +#endif // send channel open confirmation message SendMessage(new ChannelOpenConfirmationMessage(RemoteChannelNumber, LocalWindowSize, LocalPacketSize, LocalChannelNumber)); @@ -141,7 +152,7 @@ private void ShutdownSocket(SocketShutdown how) } /// - /// Closes the socket, hereby interrupting the blocking receive in . + /// Closes the socket, hereby interrupting the blocking receive in . /// private void CloseSocket() { diff --git a/src/Renci.SshNet/Channels/ChannelTypes.cs b/src/Renci.SshNet/Channels/ChannelTypes.cs index 230420232..f9b84aa65 100644 --- a/src/Renci.SshNet/Channels/ChannelTypes.cs +++ b/src/Renci.SshNet/Channels/ChannelTypes.cs @@ -21,6 +21,14 @@ internal enum ChannelTypes /// /// direct-tcpip /// - DirectTcpip + DirectTcpip, + /// + /// forwarded-streamlocal@openssh.com + /// + ForwardedStreamLocal, + /// + /// direct-streamlocal@openssh.com + /// + DirectStreamLocal, } } diff --git a/src/Renci.SshNet/Channels/IChannelDirectStreamLocal.cs b/src/Renci.SshNet/Channels/IChannelDirectStreamLocal.cs new file mode 100644 index 000000000..bf3082fdf --- /dev/null +++ b/src/Renci.SshNet/Channels/IChannelDirectStreamLocal.cs @@ -0,0 +1,46 @@ +using System; +using System.Net.Sockets; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Channels +{ + /// + /// A "direct-streamlocal@openssh.com" SSH channel. + /// + internal interface IChannelDirectStreamLocal : IDisposable + { + /// + /// Occurs when an exception is thrown while processing channel messages. + /// + event EventHandler Exception; + + /// + /// Gets a value indicating whether this channel is open. + /// + /// + /// true if this channel is open; otherwise, false. + /// + bool IsOpen { get; } + + /// + /// Gets the local channel number. + /// + /// + /// The local channel number. + /// + uint LocalChannelNumber { get; } + + /// + /// Opens a channel for a locally forwarded TCP/IP port. + /// + /// The remote unix socket to forward to. + /// The forwarded port for which the channel is opened. + /// The socket to receive requests from, and send responses from the remote host to. + void Open(string remoteSocket, IForwardedPort forwardedPort, Socket socket); + + /// + /// Binds the channel to the remote host. + /// + void Bind(); + } +} \ No newline at end of file diff --git a/src/Renci.SshNet/Channels/IChannelForwardedStreamLocal.cs b/src/Renci.SshNet/Channels/IChannelForwardedStreamLocal.cs new file mode 100644 index 000000000..76176bf4b --- /dev/null +++ b/src/Renci.SshNet/Channels/IChannelForwardedStreamLocal.cs @@ -0,0 +1,25 @@ +using System; +using System.Net; +using System.Net.Sockets; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Channels +{ + /// + /// A "forwarded-streamlocal@openssh.com" SSH channel. + /// + internal interface IChannelForwardedStreamLocal : IDisposable + { + /// + /// Occurs when an exception is thrown while processing channel messages. + /// + event EventHandler Exception; + + /// + /// Binds the channel to the specified endpoint. + /// + /// The socketPath to connect to. + /// The forwarded port for which the channel is opened. + void Bind(EndPoint remoteEndpoint, IForwardedPort forwardedPort); + } +} diff --git a/src/Renci.SshNet/Channels/IChannelForwardedTcpip.cs b/src/Renci.SshNet/Channels/IChannelForwardedTcpip.cs index 7bc165a8d..69368aa8b 100644 --- a/src/Renci.SshNet/Channels/IChannelForwardedTcpip.cs +++ b/src/Renci.SshNet/Channels/IChannelForwardedTcpip.cs @@ -19,6 +19,6 @@ internal interface IChannelForwardedTcpip : IDisposable /// /// The endpoint to connect to. /// The forwarded port for which the channel is opened. - void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort); + void Bind(EndPoint remoteEndpoint, IForwardedPort forwardedPort); } } diff --git a/src/Renci.SshNet/ForwardedPortLocal.NET.cs b/src/Renci.SshNet/ForwardedPortLocal.NET.cs index ba01ddbd3..048d48dfe 100644 --- a/src/Renci.SshNet/ForwardedPortLocal.NET.cs +++ b/src/Renci.SshNet/ForwardedPortLocal.NET.cs @@ -14,15 +14,28 @@ public partial class ForwardedPortLocal partial void InternalStart() { - var addr = DnsAbstraction.GetHostAddresses(BoundHost)[0]; - var ep = new IPEndPoint(addr, (int) BoundPort); +#if FEATURE_UNIX_SOCKETS + if (LocalUnixSocket != null) + { + _listener = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + _listener.Bind(LocalUnixSocket); + _listener.Listen(5); + } + else + { +#endif + var addr = DnsAbstraction.GetHostAddresses(BoundHost)[0]; + var ep = new IPEndPoint(addr, (int) BoundPort); - _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {NoDelay = true}; - _listener.Bind(ep); - _listener.Listen(5); + _listener = new Socket(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp) {NoDelay = true}; + _listener.Bind(ep); + _listener.Listen(5); - // update bound port (in case original was passed as zero) - BoundPort = (uint)((IPEndPoint)_listener.LocalEndPoint).Port; + // update bound port (in case original was passed as zero) + BoundPort = (uint) ((IPEndPoint) _listener.LocalEndPoint).Port; +#if FEATURE_UNIX_SOCKETS + } +#endif Session.ErrorOccured += Session_ErrorOccured; Session.Disconnected += Session_Disconnected; @@ -115,17 +128,44 @@ private void ProcessAccept(Socket clientSocket) try { - var originatorEndPoint = (IPEndPoint) clientSocket.RemoteEndPoint; - - RaiseRequestReceived(originatorEndPoint.Address.ToString(), - (uint)originatorEndPoint.Port); + var originatorAddress = ""; + var originatorPort = (uint)0; + if (clientSocket.RemoteEndPoint is IPEndPoint) + { + var originatorEndPoint = (IPEndPoint)clientSocket.RemoteEndPoint; + originatorAddress = originatorEndPoint.Address.ToString(); + originatorPort = (uint) originatorEndPoint.Port; + } +#if FEATURE_UNIX_SOCKETS + if (clientSocket.RemoteEndPoint is UnixDomainSocketEndPoint) + { + originatorAddress = ((UnixDomainSocketEndPoint)clientSocket.RemoteEndPoint).ToString(); + } +#endif + RaiseRequestReceived(originatorAddress, originatorPort); - using (var channel = Session.CreateChannelDirectTcpip()) +#if FEATURE_UNIX_SOCKETS + if (RemoteUnixSocket != null) + { + using (var channel = Session.CreateChannelDirectStreamLocal()) + { + channel.Exception += Channel_Exception; + channel.Open(RemoteUnixSocket.ToString(), this, clientSocket); + channel.Bind(); + } + } + else { - channel.Exception += Channel_Exception; - channel.Open(Host, Port, this, clientSocket); - channel.Bind(); +#endif + using (var channel = Session.CreateChannelDirectTcpip()) + { + channel.Exception += Channel_Exception; + channel.Open(Host, Port, this, clientSocket); + channel.Bind(); + } +#if FEATURE_UNIX_SOCKETS } +#endif } catch (Exception exp) { diff --git a/src/Renci.SshNet/ForwardedPortLocal.cs b/src/Renci.SshNet/ForwardedPortLocal.cs index 79246bd73..71edff224 100644 --- a/src/Renci.SshNet/ForwardedPortLocal.cs +++ b/src/Renci.SshNet/ForwardedPortLocal.cs @@ -1,4 +1,7 @@ using System; +#if FEATURE_UNIX_SOCKETS +using System.Net.Sockets; +#endif using Renci.SshNet.Common; namespace Renci.SshNet @@ -101,6 +104,83 @@ public ForwardedPortLocal(string boundHost, uint boundPort, string host, uint po _status = ForwardedPortStatus.Stopped; } +#if FEATURE_UNIX_SOCKETS + /// + /// Gets the bound unix socket. + /// + public UnixDomainSocketEndPoint LocalUnixSocket { get; private set; } + + /// + /// Gets the forwarded unix socket. + /// + public UnixDomainSocketEndPoint RemoteUnixSocket { get; private set; } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + public ForwardedPortLocal(UnixDomainSocketEndPoint localSocket, UnixDomainSocketEndPoint remoteSocket) + { + if (localSocket == null) + throw new ArgumentNullException("localSocket"); + + if (remoteSocket == null) + throw new ArgumentNullException("remoteSocket"); + + LocalUnixSocket = localSocket; + RemoteUnixSocket = remoteSocket; + _status = ForwardedPortStatus.Stopped; + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// + public ForwardedPortLocal(UnixDomainSocketEndPoint localSocket, string host, uint port) + { + if (localSocket == null) + throw new ArgumentNullException("localSocket"); + + if (host == null) + throw new ArgumentNullException("host"); + + port.ValidatePort("port"); + + LocalUnixSocket = localSocket; + Host = host; + Port = port; + _status = ForwardedPortStatus.Stopped; + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// + public ForwardedPortLocal(string boundHost, uint boundPort, UnixDomainSocketEndPoint remoteSocket) + { + if (boundHost == null) + throw new ArgumentNullException("boundHost"); + + if (remoteSocket == null) + throw new ArgumentNullException("remoteSocket"); + + boundPort.ValidatePort("boundPort"); + + RemoteUnixSocket = remoteSocket; + BoundHost = boundHost; + BoundPort = boundPort; + _status = ForwardedPortStatus.Stopped; + } +#endif + /// /// Starts local port forwarding. /// diff --git a/src/Renci.SshNet/ForwardedPortRemote.cs b/src/Renci.SshNet/ForwardedPortRemote.cs index b2ed15fe4..bb78b6a40 100644 --- a/src/Renci.SshNet/ForwardedPortRemote.cs +++ b/src/Renci.SshNet/ForwardedPortRemote.cs @@ -4,6 +4,9 @@ using Renci.SshNet.Common; using System.Globalization; using System.Net; +#if FEATURE_UNIX_SOCKETS +using System.Net.Sockets; +#endif using Renci.SshNet.Abstractions; namespace Renci.SshNet @@ -129,6 +132,83 @@ public ForwardedPortRemote(string boundHost, uint boundPort, string host, uint p { } +#if FEATURE_UNIX_SOCKETS + /// + /// Gets the bound unix socket. + /// + public UnixDomainSocketEndPoint LocalUnixSocket { get; private set; } + + /// + /// Gets the forwarded unix socket. + /// + public UnixDomainSocketEndPoint RemoteUnixSocket { get; private set; } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + public ForwardedPortRemote(UnixDomainSocketEndPoint remoteSocket, UnixDomainSocketEndPoint localSocket) + { + if (remoteSocket == null) + throw new ArgumentNullException("remoteSocket"); + + if (localSocket == null) + throw new ArgumentNullException("localSocket"); + + RemoteUnixSocket = remoteSocket; + LocalUnixSocket = localSocket; + _status = ForwardedPortStatus.Stopped; + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// + public ForwardedPortRemote(UnixDomainSocketEndPoint remoteSocket, IPAddress hostAddress, uint port) + { + if (remoteSocket == null) + throw new ArgumentNullException("remoteSocket"); + + if (hostAddress == null) + throw new ArgumentNullException("hostAddress"); + + port.ValidatePort("port"); + + RemoteUnixSocket = remoteSocket; + HostAddress = hostAddress; + Port = port; + _status = ForwardedPortStatus.Stopped; + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// + public ForwardedPortRemote(IPAddress boundHostAddress, uint boundPort, UnixDomainSocketEndPoint localSocket) + { + if (boundHostAddress == null) + throw new ArgumentNullException("boundHostAddress"); + + if (localSocket == null) + throw new ArgumentNullException("localSocket"); + + boundPort.ValidatePort("boundPort"); + + LocalUnixSocket = localSocket; + BoundHostAddress = boundHostAddress; + BoundPort = boundPort; + _status = ForwardedPortStatus.Stopped; + } +#endif + /// /// Starts remote port forwarding. /// @@ -150,12 +230,27 @@ protected override void StartPort() Session.ChannelOpenReceived += Session_ChannelOpening; // send global request to start forwarding - Session.SendMessage(new TcpIpForwardGlobalRequestMessage(BoundHost, BoundPort)); +#if FEATURE_UNIX_SOCKETS + if (RemoteUnixSocket != null) + { + Session.SendMessage(new StreamLocalForwardGlobalRequestMessage(RemoteUnixSocket.ToString())); + } + else + { +#endif + Session.SendMessage(new TcpIpForwardGlobalRequestMessage(BoundHost, BoundPort)); +#if FEATURE_UNIX_SOCKETS + } +#endif // wat for response on global request to start direct tcpip Session.WaitOnHandle(_globalRequestResponse); if (!_requestStatus) { +#if FEATURE_UNIX_SOCKETS + if (RemoteUnixSocket != null) + throw new SshException(string.Format(CultureInfo.CurrentCulture, "Forwarding for '{0}' failed to start.", RemoteUnixSocket.ToString())); +#endif throw new SshException(string.Format(CultureInfo.CurrentCulture, "Port forwarding for '{0}' port '{1}' failed to start.", Host, Port)); } } @@ -188,7 +283,18 @@ protected override void StopPort(TimeSpan timeout) base.StopPort(timeout); // send global request to cancel direct tcpip - Session.SendMessage(new CancelTcpIpForwardGlobalRequestMessage(BoundHost, BoundPort)); +#if FEATURE_UNIX_SOCKETS + if (RemoteUnixSocket != null) + { + Session.SendMessage(new CancelStreamLocalForwardGlobalRequestMessage(RemoteUnixSocket.ToString())); + } + else + { +#endif + Session.SendMessage(new CancelTcpIpForwardGlobalRequestMessage(BoundHost, BoundPort)); +#if FEATURE_UNIX_SOCKETS + } +#endif // wait for response on global request to cancel direct tcpip or completion of message // listener loop (in which case response on global request can never be received) WaitHandle.WaitAny(new[] { _globalRequestResponse, Session.MessageListenerCompleted }, timeout); @@ -224,9 +330,10 @@ protected override void CheckDisposed() private void Session_ChannelOpening(object sender, MessageEventArgs e) { var channelOpenMessage = e.Message; - var info = channelOpenMessage.Info as ForwardedTcpipChannelInfo; - if (info != null) + + if (channelOpenMessage.Info is ForwardedTcpipChannelInfo) { + var info = (ForwardedTcpipChannelInfo) channelOpenMessage.Info; // Ensure this is the corresponding request if (info.ConnectedAddress == BoundHost && info.ConnectedPort == BoundPort) { @@ -252,7 +359,18 @@ private void Session_ChannelOpening(object sender, MessageEventArgs + { + // capture the countdown event that we're adding a count to, as we need to make sure that we'll be signaling + // that same instance; the instance field for the countdown event is re-initialize when the port is restarted + // and that time there may still be pending requests + var pendingChannelCountdown = _pendingChannelCountdown; + + pendingChannelCountdown.AddCount(); + + try + { + RaiseRequestReceived(streamChannelInfo.SocketPath, 0); + + using (var channel = Session.CreateChannelForwardedStreamLocal(channelOpenMessage.LocalChannelNumber, channelOpenMessage.InitialWindowSize, channelOpenMessage.MaximumPacketSize)) + { + channel.Exception += Channel_Exception; + if (LocalUnixSocket != null) + { + channel.Bind(LocalUnixSocket, this); + } + else + { + channel.Bind(new IPEndPoint(HostAddress, (int) Port), this); + } + } + } + catch (Exception exp) + { + RaiseExceptionEvent(exp); + } + finally + { + // take into account that CountdownEvent has since been disposed; when stopping the port we + // wait for a given time for the channels to close, but once that timeout period has elapsed + // the CountdownEvent will be disposed + try + { + pendingChannelCountdown.Signal(); + } + catch (ObjectDisposedException) + { + } + } + }); + } + } +#endif } /// diff --git a/src/Renci.SshNet/ISession.cs b/src/Renci.SshNet/ISession.cs index cd950da52..0296fe355 100644 --- a/src/Renci.SshNet/ISession.cs +++ b/src/Renci.SshNet/ISession.cs @@ -70,6 +70,14 @@ internal interface ISession : IDisposable /// IChannelDirectTcpip CreateChannelDirectTcpip(); + /// + /// Create a new channel for a locally forwarded unix domain socket. + /// + /// + /// A new channel for a locally forwarded unix domain socket. + /// + IChannelDirectStreamLocal CreateChannelDirectStreamLocal(); + /// /// Creates a "forwarded-tcpip" SSH channel. /// @@ -78,6 +86,14 @@ internal interface ISession : IDisposable /// IChannelForwardedTcpip CreateChannelForwardedTcpip(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize); + /// + /// Creates a "forwarded-streamlocal@openssh.com" SSH channel. + /// + /// + /// A new "forwarded-streamlocal@openssh.com" SSH channel. + /// + IChannelForwardedStreamLocal CreateChannelForwardedStreamLocal(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize); + /// /// Disconnects from the server. /// diff --git a/src/Renci.SshNet/Messages/Connection/CancelStreamLocalForwardGlobalRequestMessage.cs b/src/Renci.SshNet/Messages/Connection/CancelStreamLocalForwardGlobalRequestMessage.cs new file mode 100644 index 000000000..cce164b47 --- /dev/null +++ b/src/Renci.SshNet/Messages/Connection/CancelStreamLocalForwardGlobalRequestMessage.cs @@ -0,0 +1,58 @@ +using System; + +namespace Renci.SshNet.Messages.Connection +{ + internal class CancelStreamLocalForwardGlobalRequestMessage : GlobalRequestMessage + { + private byte[] _socketPath; + + public CancelStreamLocalForwardGlobalRequestMessage(string socketPath) + : base(Ascii.GetBytes("cancel-streamlocal-forward@openssh.com"), true) + { + SocketPath = socketPath; + } + + /// + /// Gets the socket path to bind to. + /// + public string SocketPath + { + get { return Utf8.GetString(_socketPath, 0, _socketPath.Length); } + private set { _socketPath = Utf8.GetBytes(value); } + } + + /// + /// Gets the size of the message in bytes. + /// + /// + /// The size of the messages in bytes. + /// + protected override int BufferCapacity + { + get + { + var capacity = base.BufferCapacity; + capacity += 4; // AddressToBind length + capacity += _socketPath.Length; // AddressToBind + return capacity; + } + } + + /// + /// Called when type specific data need to be loaded. + /// + protected override void LoadData() + { + throw new NotImplementedException(); + } + + /// + /// Called when type specific data need to be saved. + /// + protected override void SaveData() + { + base.SaveData(); + WriteBinaryString(_socketPath); + } + } +} diff --git a/src/Renci.SshNet/Messages/Connection/ChannelOpen/ChannelOpenMessage.cs b/src/Renci.SshNet/Messages/Connection/ChannelOpen/ChannelOpenMessage.cs index 1f8bb3e92..5a09f74e0 100644 --- a/src/Renci.SshNet/Messages/Connection/ChannelOpen/ChannelOpenMessage.cs +++ b/src/Renci.SshNet/Messages/Connection/ChannelOpen/ChannelOpenMessage.cs @@ -127,6 +127,12 @@ protected override void LoadData() case ForwardedTcpipChannelInfo.NAME: Info = new ForwardedTcpipChannelInfo(_infoBytes); break; + case DirectStreamLocalChannelInfo.NAME: + Info = new DirectTcpipChannelInfo(_infoBytes); + break; + case ForwardedStreamChannelInfo.NAME: + Info = new ForwardedStreamChannelInfo(_infoBytes); + break; default: throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Channel type '{0}' is not supported.", channelName)); } diff --git a/src/Renci.SshNet/Messages/Connection/ChannelOpen/DirectStreamLocalChannelInfo.cs b/src/Renci.SshNet/Messages/Connection/ChannelOpen/DirectStreamLocalChannelInfo.cs new file mode 100644 index 000000000..0b82e81e3 --- /dev/null +++ b/src/Renci.SshNet/Messages/Connection/ChannelOpen/DirectStreamLocalChannelInfo.cs @@ -0,0 +1,119 @@ +using System; + +namespace Renci.SshNet.Messages.Connection +{ + /// + /// Used to open "direct-streamlocal@openssh.com"" channel type + /// + internal class DirectStreamLocalChannelInfo : ChannelOpenInfo + { + private byte[] _socketPath; + private byte[] _originatorAddress; + + /// + /// Specifies channel open type + /// + public const string NAME = "direct-streamlocal@openssh.com"; + + /// + /// Gets the type of the channel to open. + /// + /// + /// The type of the channel to open. + /// + public override string ChannelType + { + get { return NAME; } + } + + /// + /// Gets the host to connect. + /// + public string SocketPath + { + get { return Utf8.GetString(_socketPath, 0, _socketPath.Length); } + private set { _socketPath = Utf8.GetBytes(value); } + } + + /// + /// Gets the originator address. + /// + public string OriginatorAddress + { + get { return Utf8.GetString(_originatorAddress, 0, _originatorAddress.Length); } + private set { _originatorAddress = Utf8.GetBytes(value); } + } + + /// + /// Gets the originator port. + /// + public uint OriginatorPort { get; private set; } + + /// + /// Gets the size of the message in bytes. + /// + /// + /// The size of the messages in bytes. + /// + protected override int BufferCapacity + { + get + { + var capacity = base.BufferCapacity; + capacity += 4; // SocketPath length + capacity += _socketPath.Length; // SocketPath + capacity += 4; // OriginatorAddress length + capacity += _originatorAddress.Length; // OriginatorAddress + capacity += 4; // OriginatorPort + return capacity; + } + } + + /// + /// Initializes a new instance of the class from the + /// specified data. + /// + /// is null. + public DirectStreamLocalChannelInfo(byte[] data) + { + Load(data); + } + + /// + /// Initializes a new instance of the class + /// + /// + /// + /// + public DirectStreamLocalChannelInfo(string socketPath, string originatorAddress, uint originatorPort) + { + SocketPath = socketPath; + OriginatorAddress = originatorAddress; + OriginatorPort = originatorPort; + } + + /// + /// Called when type specific data need to be loaded. + /// + protected override void LoadData() + { + base.LoadData(); + + _socketPath = ReadBinary(); + _originatorAddress = ReadBinary(); + OriginatorPort = ReadUInt32(); + } + + /// + /// Called when type specific data need to be saved. + /// + protected override void SaveData() + { + base.SaveData(); + + WriteBinaryString(_socketPath); + WriteBinaryString(_originatorAddress); + Write(OriginatorPort); + } + } +} \ No newline at end of file diff --git a/src/Renci.SshNet/Messages/Connection/ChannelOpen/ForwardedStreamLocalChannelInfo.cs b/src/Renci.SshNet/Messages/Connection/ChannelOpen/ForwardedStreamLocalChannelInfo.cs new file mode 100644 index 000000000..6f1bae1d7 --- /dev/null +++ b/src/Renci.SshNet/Messages/Connection/ChannelOpen/ForwardedStreamLocalChannelInfo.cs @@ -0,0 +1,109 @@ +using System; + +namespace Renci.SshNet.Messages.Connection +{ + /// + /// Used to open "forwarded-streamlocal@openssh.com" channel type + /// + internal class ForwardedStreamChannelInfo : ChannelOpenInfo + { + private byte[] _socketPath; + private byte[] _reserved; + + /// + /// Specifies channel open type + /// + public const string NAME = "forwarded-streamlocal@openssh.com"; + + /// + /// Initializes a new instance of the class from the + /// specified data. + /// + /// is null. + public ForwardedStreamChannelInfo(byte[] data) + { + Load(data); + } + + /// + /// Initializes a new instance with the specified connector + /// address and port, and originator address and port. + /// + public ForwardedStreamChannelInfo(string socketPath, string reserved = "") + { + SocketPath = socketPath; + Reserved = reserved; + } + + /// + /// Gets the type of the channel to open. + /// + /// + /// The type of the channel to open. + /// + public override string ChannelType + { + get { return NAME; } + } + + /// + /// Gets the connected address. + /// + public string SocketPath + { + get { return Utf8.GetString(_socketPath, 0, _socketPath.Length); } + private set { _socketPath = Utf8.GetBytes(value); } + } + + /// + /// Gets the originator address. + /// + public string Reserved + { + get { return Utf8.GetString(_reserved, 0, _reserved.Length); } + private set { _reserved = Utf8.GetBytes(value); } + } + + /// + /// Gets the size of the message in bytes. + /// + /// + /// The size of the messages in bytes. + /// + protected override int BufferCapacity + { + get + { + var capacity = base.BufferCapacity; + capacity += 4; // ConnectedAddress length + capacity += _socketPath.Length; // ConnectedAddress + capacity += 4; // ConnectedPort + capacity += 4; // Reserved length + capacity += _reserved.Length; // Reserved + return capacity; + } + } + + /// + /// Called when type specific data need to be loaded. + /// + protected override void LoadData() + { + base.LoadData(); + + _socketPath = ReadBinary(); + _reserved = ReadBinary(); + } + + /// + /// Called when type specific data need to be saved. + /// + protected override void SaveData() + { + base.SaveData(); + + WriteBinaryString(_socketPath); + WriteBinaryString(_reserved); + } + } +} diff --git a/src/Renci.SshNet/Messages/Connection/StreamLocalForwardGlobalRequestMessage.cs b/src/Renci.SshNet/Messages/Connection/StreamLocalForwardGlobalRequestMessage.cs new file mode 100644 index 000000000..f633c041b --- /dev/null +++ b/src/Renci.SshNet/Messages/Connection/StreamLocalForwardGlobalRequestMessage.cs @@ -0,0 +1,58 @@ +using System; + +namespace Renci.SshNet.Messages.Connection +{ + internal class StreamLocalForwardGlobalRequestMessage : GlobalRequestMessage + { + private byte[] _socketPath; + + public StreamLocalForwardGlobalRequestMessage(string socketpath) + : base(Ascii.GetBytes("streamlocal-forward@openssh.com"), true) + { + SocketPath = socketpath; + } + + /// + /// Gets the socket path to bind to. + /// + public string SocketPath + { + get { return Utf8.GetString(_socketPath, 0, _socketPath.Length); } + private set { _socketPath = Utf8.GetBytes(value); } + } + + /// + /// Gets the size of the message in bytes. + /// + /// + /// The size of the messages in bytes. + /// + protected override int BufferCapacity + { + get + { + var capacity = base.BufferCapacity; + capacity += 4; // AddressToBind length + capacity += _socketPath.Length; // AddressToBind + return capacity; + } + } + + /// + /// Called when type specific data need to be loaded. + /// + protected override void LoadData() + { + throw new NotImplementedException(); + } + + /// + /// Called when type specific data need to be saved. + /// + protected override void SaveData() + { + base.SaveData(); + WriteBinaryString(_socketPath); + } + } +} diff --git a/src/Renci.SshNet/PrivateKeyFile.cs b/src/Renci.SshNet/PrivateKeyFile.cs index 1424134f6..2195ac021 100644 --- a/src/Renci.SshNet/PrivateKeyFile.cs +++ b/src/Renci.SshNet/PrivateKeyFile.cs @@ -79,6 +79,15 @@ public class PrivateKeyFile : IDisposable /// public HostAlgorithm HostKey { get; private set; } + /// + /// Initializes a new instance of the class. + /// + /// The key. + public PrivateKeyFile(Key key) + { + HostKey = new KeyHostAlgorithm(key.ToString(), key); + } + /// /// Initializes a new instance of the class. /// diff --git a/src/Renci.SshNet/Renci.SshNet.csproj b/src/Renci.SshNet/Renci.SshNet.csproj index 124ce9d4b..dd15d7d43 100644 --- a/src/Renci.SshNet/Renci.SshNet.csproj +++ b/src/Renci.SshNet/Renci.SshNet.csproj @@ -7,7 +7,7 @@ ../Renci.SshNet.snk 5 true - net35;net40;netstandard1.3;netstandard2.0 + net35;net40;netstandard1.3;netstandard2.0;netstandard2.1