From 3422faa843416850cb8e4e66845b3bc24e0e1115 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Sun, 11 May 2025 18:39:21 +0200 Subject: [PATCH 1/4] Fix deadlock in SftpClient.UploadFile upon error --- src/Renci.SshNet/Sftp/SftpSession.cs | 2 +- src/Renci.SshNet/SftpClient.cs | 87 ++++++++++++++++++---------- 2 files changed, 57 insertions(+), 32 deletions(-) diff --git a/src/Renci.SshNet/Sftp/SftpSession.cs b/src/Renci.SshNet/Sftp/SftpSession.cs index de83bd887..96c652de9 100644 --- a/src/Renci.SshNet/Sftp/SftpSession.cs +++ b/src/Renci.SshNet/Sftp/SftpSession.cs @@ -2272,7 +2272,7 @@ public uint CalculateOptimalWriteLength(uint bufferSize, byte[] handle) return Math.Min(bufferSize, maximumPacketSize) - lengthOfNonDataProtocolFields; } - private static SshException GetSftpException(SftpStatusResponse response) + internal static SshException GetSftpException(SftpStatusResponse response) { #pragma warning disable IDE0010 // Add missing cases switch (response.StatusCode) diff --git a/src/Renci.SshNet/SftpClient.cs b/src/Renci.SshNet/SftpClient.cs index 50a2e9cbd..8a7fab1a8 100644 --- a/src/Renci.SshNet/SftpClient.cs +++ b/src/Renci.SshNet/SftpClient.cs @@ -1,11 +1,13 @@ #nullable enable using System; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; using System.Net; using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -2456,56 +2458,79 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo // create buffer of optimal length var buffer = new byte[_sftpSession.CalculateOptimalWriteLength(_bufferSize, handle)]; - var bytesRead = input.Read(buffer, 0, buffer.Length); + int bytesRead; var expectedResponses = 0; - var responseReceivedWaitHandle = new AutoResetEvent(initialState: false); + using var mres = new ManualResetEventSlim(initialState: false); - do + ExceptionDispatchInfo? exception = null; + + while ((bytesRead = input.Read(buffer, 0, buffer.Length)) != 0) { - // Cancel upload if (asyncResult is not null && asyncResult.IsUploadCanceled) { break; } - if (bytesRead > 0) + exception?.Throw(); + + var writtenBytes = offset + (ulong)bytesRead; + + _ = Interlocked.Increment(ref expectedResponses); + mres.Reset(); + + _sftpSession.RequestWrite(handle, offset, buffer, offset: 0, bytesRead, wait: null, s => { - var writtenBytes = offset + (ulong)bytesRead; + var setHandle = false; - _sftpSession.RequestWrite(handle, offset, buffer, offset: 0, bytesRead, wait: null, s => + try + { + if (Interlocked.Decrement(ref expectedResponses) == 0) { - if (s.StatusCode == StatusCodes.Ok) - { - _ = Interlocked.Decrement(ref expectedResponses); - _ = responseReceivedWaitHandle.Set(); + setHandle = true; + } - asyncResult?.Update(writtenBytes); + if (Sftp.SftpSession.GetSftpException(s) is Exception ex) + { + exception = ExceptionDispatchInfo.Capture(ex); + } - // Call callback to report number of bytes written - if (uploadCallback is not null) - { - // Execute callback on different thread - ThreadAbstraction.ExecuteThread(() => uploadCallback(writtenBytes)); - } - } - }); + if (exception is not null) + { + setHandle = true; + return; + } - _ = Interlocked.Increment(ref expectedResponses); + Debug.Assert(s.StatusCode == StatusCodes.Ok); - offset += (ulong)bytesRead; + asyncResult?.Update(writtenBytes); - bytesRead = input.Read(buffer, 0, buffer.Length); - } - else if (expectedResponses > 0) - { - // Wait for expectedResponses to change - _sftpSession.WaitOnHandle(responseReceivedWaitHandle, _operationTimeout); - } + // Call callback to report number of bytes written + if (uploadCallback is not null) + { + // Execute callback on different thread + ThreadAbstraction.ExecuteThread(() => uploadCallback(writtenBytes)); + } + } + finally + { + if (setHandle) + { + mres.Set(); + } + } + }); + + offset += (ulong)bytesRead; } - while (expectedResponses > 0 || bytesRead > 0); + + if (expectedResponses != 0) + { + _sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout); + } + + exception?.Throw(); _sftpSession.RequestClose(handle); - responseReceivedWaitHandle.Dispose(); } private async Task InternalUploadFileAsync(Stream input, string path, CancellationToken cancellationToken) From c6ba5945ee7fb3df26387e01a940b7bd78c73f93 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Sun, 11 May 2025 18:40:07 +0200 Subject: [PATCH 2/4] Make RequestWrite deterministic wrt. exception handling --- src/Renci.SshNet/Sftp/SftpSession.cs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/Renci.SshNet/Sftp/SftpSession.cs b/src/Renci.SshNet/Sftp/SftpSession.cs index 96c652de9..3200d8257 100644 --- a/src/Renci.SshNet/Sftp/SftpSession.cs +++ b/src/Renci.SshNet/Sftp/SftpSession.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Globalization; using System.Text; using System.Threading; @@ -914,6 +915,8 @@ public void RequestWrite(byte[] handle, AutoResetEvent wait, Action writeCompleted = null) { + Debug.Assert((wait is null) != (writeCompleted is null), "Should have one parameter or the other."); + SshException exception = null; var request = new SftpWriteRequest(ProtocolVersion, @@ -925,10 +928,15 @@ public void RequestWrite(byte[] handle, length, response => { - writeCompleted?.Invoke(response); - - exception = GetSftpException(response); - wait?.SetIgnoringObjectDisposed(); + if (writeCompleted is not null) + { + writeCompleted.Invoke(response); + } + else + { + exception = GetSftpException(response); + wait.SetIgnoringObjectDisposed(); + } }); SendRequest(request); @@ -936,11 +944,11 @@ public void RequestWrite(byte[] handle, if (wait is not null) { WaitOnHandle(wait, OperationTimeout); - } - if (exception is not null) - { - throw exception; + if (exception is not null) + { + throw exception; + } } } From 58be453d9fd211046a07476cf78d31d916cb7399 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Sun, 11 May 2025 18:40:12 +0200 Subject: [PATCH 3/4] add regression test; fix race --- .../Sftp/Responses/SftpStatusResponse.cs | 9 +- src/Renci.SshNet/SftpClient.cs | 17 +- .../Classes/SftpClientTest.UploadFile.cs | 252 ++++++++++++++++++ 3 files changed, 270 insertions(+), 8 deletions(-) create mode 100644 test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs diff --git a/src/Renci.SshNet/Sftp/Responses/SftpStatusResponse.cs b/src/Renci.SshNet/Sftp/Responses/SftpStatusResponse.cs index 4e31154b6..07aaa3f64 100644 --- a/src/Renci.SshNet/Sftp/Responses/SftpStatusResponse.cs +++ b/src/Renci.SshNet/Sftp/Responses/SftpStatusResponse.cs @@ -12,7 +12,7 @@ public SftpStatusResponse(uint protocolVersion) { } - public StatusCodes StatusCode { get; private set; } + public StatusCodes StatusCode { get; set; } public string ErrorMessage { get; private set; } @@ -39,5 +39,12 @@ protected override void LoadData() Language = ReadString(Ascii); } } + + protected override void SaveData() + { + base.SaveData(); + + Write((uint)StatusCode); + } } } diff --git a/src/Renci.SshNet/SftpClient.cs b/src/Renci.SshNet/SftpClient.cs index 8a7fab1a8..3cc4562f3 100644 --- a/src/Renci.SshNet/SftpClient.cs +++ b/src/Renci.SshNet/SftpClient.cs @@ -2460,6 +2460,10 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo int bytesRead; var expectedResponses = 0; + + // We will send out all the write requests without waiting for each response. + // Afterwards, we may wait on this handle until all responses are received + // or an error has occured. using var mres = new ManualResetEventSlim(initialState: false); ExceptionDispatchInfo? exception = null; @@ -2484,11 +2488,6 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo try { - if (Interlocked.Decrement(ref expectedResponses) == 0) - { - setHandle = true; - } - if (Sftp.SftpSession.GetSftpException(s) is Exception ex) { exception = ExceptionDispatchInfo.Capture(ex); @@ -2513,7 +2512,7 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo } finally { - if (setHandle) + if (Interlocked.Decrement(ref expectedResponses) == 0 || setHandle) { mres.Set(); } @@ -2523,7 +2522,11 @@ private void InternalUploadFile(Stream input, string path, Flags flags, SftpUplo offset += (ulong)bytesRead; } - if (expectedResponses != 0) + // Make sure the read of exception cannot be executed ahead of + // the read of expectedResponses so that we do not miss an + // exception. + + if (Volatile.Read(ref expectedResponses) != 0) { _sftpSession.WaitOnHandle(mres.WaitHandle, _operationTimeout); } diff --git a/test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs b/test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs new file mode 100644 index 000000000..98c96ea83 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs @@ -0,0 +1,252 @@ +using System; +using System.IO; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Moq; + +using Renci.SshNet.Channels; +using Renci.SshNet.Common; +using Renci.SshNet.Connection; +using Renci.SshNet.Messages; +using Renci.SshNet.Messages.Authentication; +using Renci.SshNet.Messages.Connection; +using Renci.SshNet.Sftp; +using Renci.SshNet.Sftp.Responses; + +namespace Renci.SshNet.Tests.Classes +{ + public partial class SftpClientTest + { + [TestMethod] + public void UploadFile_ObservesErrorResponses() + { + // A regression test for UploadFile hanging instead of observing + // error responses from the server. + // https://github.com/sshnet/SSH.NET/issues/957 + + var serviceFactoryMock = new Mock(); + + var connInfo = new PasswordConnectionInfo("host", "user", "pwd"); + + var session = new MySession(connInfo); + + var concreteServiceFactory = new ServiceFactory(); + + serviceFactoryMock + .Setup(p => p.CreateSession(It.IsAny(), It.IsAny())) + .Returns(session); + + serviceFactoryMock + .Setup(p => p.CreateSftpResponseFactory()) + .Returns(concreteServiceFactory.CreateSftpResponseFactory); + + serviceFactoryMock + .Setup(p => p.CreateSftpSession(session, It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(concreteServiceFactory.CreateSftpSession); + + using var client = new SftpClient(connInfo, false, serviceFactoryMock.Object); + client.Connect(); + + Assert.Throws(() => client.UploadFile( + new OneByteStream(new MemoryStream("Hello World"u8.ToArray())), + "path.txt")); + } + +#pragma warning disable IDE0022 // Use block body for method +#pragma warning disable IDE0025 // Use block body for property +#pragma warning disable IDE0027 // Use block body for accessor +#pragma warning disable CS0067 // event is unused + + private class MySession(ConnectionInfo connectionInfo) : ISession + { + public IConnectionInfo ConnectionInfo => connectionInfo; + + public event EventHandler> ChannelCloseReceived; + public event EventHandler> ChannelDataReceived; + public event EventHandler> ChannelEofReceived; + public event EventHandler> ChannelExtendedDataReceived; + public event EventHandler> ChannelFailureReceived; + public event EventHandler> ChannelOpenConfirmationReceived; + public event EventHandler> ChannelOpenFailureReceived; + public event EventHandler> ChannelOpenReceived; + public event EventHandler> ChannelRequestReceived; + public event EventHandler> ChannelSuccessReceived; + public event EventHandler> ChannelWindowAdjustReceived; + public event EventHandler Disconnected; + public event EventHandler ErrorOccured; + public event EventHandler ServerIdentificationReceived; + public event EventHandler HostKeyReceived; + public event EventHandler> RequestSuccessReceived; + public event EventHandler> RequestFailureReceived; + public event EventHandler> UserAuthenticationBannerReceived; + + private uint _numRequests; + private int _numWriteRequests; + + public void SendMessage(Message message) + { + // Initialisation sequence for SFTP session + + if (message is ChannelOpenMessage) + { + ChannelOpenConfirmationReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelOpenConfirmationMessage(0, int.MaxValue, int.MaxValue, 0))); + } + else if (message is ChannelRequestMessage) + { + ChannelSuccessReceived?.Invoke( + this, + new MessageEventArgs(new ChannelSuccessMessage(0))); + } + else if (message is ChannelDataMessage dataMsg) + { + if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Init) + { + ChannelDataReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelDataMessage(0, new SftpVersionResponse() { Version = 3 }.GetBytes()))); + } + else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.RealPath) + { + ChannelDataReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelDataMessage(0, + new SftpNameResponse(3, Encoding.UTF8) + { + ResponseId = ++_numRequests, + Files = [new("thepath", new SftpFileAttributes(default, default, default, default, default, default, default))] + }.GetBytes()))); + } + else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Open) + { + ChannelDataReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelDataMessage(0, + new SftpHandleResponse(3) + { + ResponseId = ++_numRequests, + Handle = "file"u8.ToArray() + }.GetBytes()))); + } + + // --------- The actual interesting part of all of this --------- + // + else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Write) + { + // Fail the 5th write request + var statusCode = ++_numWriteRequests == 5 ? StatusCodes.PermissionDenied : StatusCodes.Ok; + var responseId = ++_numRequests; + + // Dispatch the responses on a different thread to simulate reality. + _ = Task.Run(() => + { + ChannelDataReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelDataMessage(0, + new SftpStatusResponse(3) + { + ResponseId = responseId, + StatusCode = statusCode + }.GetBytes()))); + }); + } + // + // -------------------------------------------------------------- + } + } + + public bool IsConnected => false; + + public SemaphoreSlim SessionSemaphore { get; } = new(1); + + public IChannelSession CreateChannelSession() => new ChannelSession(this, 0, int.MaxValue, int.MaxValue); + + public WaitHandle MessageListenerCompleted => throw new NotImplementedException(); + + public void Connect() + { + } + + public Task ConnectAsync(CancellationToken cancellationToken) => throw new NotImplementedException(); + + public IChannelDirectTcpip CreateChannelDirectTcpip() => throw new NotImplementedException(); + + public IChannelForwardedTcpip CreateChannelForwardedTcpip(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize) + => throw new NotImplementedException(); + + public void Dispose() + { + } + + public void OnDisconnecting() + { + } + + public void Disconnect() => throw new NotImplementedException(); + + public void RegisterMessage(string messageName) => throw new NotImplementedException(); + + public bool TrySendMessage(Message message) => throw new NotImplementedException(); + + public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout, out Exception exception) => throw new NotImplementedException(); + + public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException(); + + public void UnRegisterMessage(string messageName) => throw new NotImplementedException(); + + public void WaitOnHandle(WaitHandle waitHandle) + { + } + + public void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException(); + } + + private class OneByteStream : Stream + { + private readonly Stream _stream; + + public OneByteStream(Stream stream) + { + _stream = stream; + } + + public override bool CanRead => _stream.CanRead; + + public override bool CanSeek => throw new NotImplementedException(); + + public override bool CanWrite => throw new NotImplementedException(); + + public override long Length => _stream.Length; + + public override long Position + { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + public override void Flush() => throw new NotImplementedException(); + + public override int Read(byte[] buffer, int offset, int count) + { + return _stream.Read(buffer, offset, Math.Min(1, count)); + } + + public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException(); + + public override void SetLength(long value) => throw new NotImplementedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException(); + } + } +} From fc7b437f8dcbc5f0e0721d4111895ed0b186f183 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Tue, 27 May 2025 08:14:20 +0200 Subject: [PATCH 4/4] x --- test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs b/test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs index 98c96ea83..5bb4c916e 100644 --- a/test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs +++ b/test/Renci.SshNet.Tests/Classes/SftpClientTest.UploadFile.cs @@ -1,6 +1,5 @@ using System; using System.IO; -using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks;