From 65f24bcee4c2de754878b2fe93027e2b8d057453 Mon Sep 17 00:00:00 2001 From: Hilari Moragrega Date: Mon, 21 Feb 2022 12:19:31 +0100 Subject: [PATCH] Improved test with CR feedback --- client.go | 4 +--- client_integration_test.go | 47 ++++++++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index f5bd39ee..d8c979a0 100644 --- a/client.go +++ b/client.go @@ -1180,9 +1180,7 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { written += int64(m) if wErr != nil { - if err == nil || err == io.EOF { - err = wErr - } + return written, wErr } } diff --git a/client_integration_test.go b/client_integration_test.go index d7c10408..1a855d42 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1249,10 +1249,19 @@ func TestClientReadSequential(t *testing.T) { } } -type writerFunc func(b []byte) (int, error) +type lastChunkErrSequentialWriter struct { + expected int + written int + writtenReturn int +} -func (f writerFunc) Write(b []byte) (int, error) { - return f(b) +func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) { + chunkSize := len(b) + w.written += chunkSize + if w.written == w.expected { + return w.writtenReturn, errors.New("test error") + } + return chunkSize, nil } func TestClientWriteSequential_WriterErr(t *testing.T) { @@ -1260,31 +1269,35 @@ func TestClientWriteSequential_WriterErr(t *testing.T) { defer cmd.Wait() defer sftp.Close() - sftp.disableConcurrentReads = true - d, err := ioutil.TempDir("", "sftptest-writesequential") + d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr") require.NoError(t, err) defer os.RemoveAll(d) - f, err := ioutil.TempFile(d, "write-sequential-test") + var ( + content = []byte("hello world") + shortWrite = 2 + ) + w := lastChunkErrSequentialWriter{ + expected: len(content), + writtenReturn: shortWrite, + } + + f, err := ioutil.TempFile(d, "write-sequential-writeerr-test") require.NoError(t, err) fname := f.Name() - content := []byte("hello world") - f.Write(content) - f.Close() + n, err := f.Write(content) + require.NoError(t, err) + require.Equal(t, n, len(content)) + require.NoError(t, f.Close()) sftpFile, err := sftp.Open(fname) require.NoError(t, err) defer sftpFile.Close() - want := errors.New("error writing") - n, got := io.Copy(writerFunc(func(b []byte) (int, error) { - return 10, want - }), sftpFile) - - require.Error(t, got) - assert.ErrorIs(t, want, got) - assert.Equal(t, int64(10), n) + gotWritten, gotErr := sftpFile.writeToSequential(&w) + require.NotErrorIs(t, io.EOF, gotErr) + require.Equal(t, int64(shortWrite), gotWritten) } func TestClientReadDir(t *testing.T) {