Skip to content

Commit

Permalink
writeToSequential: improve tests for write errors
Browse files Browse the repository at this point in the history
  • Loading branch information
drakkan committed Mar 3, 2022
1 parent 65f24bc commit ae00b32
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 28 deletions.
6 changes: 3 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1176,11 +1176,11 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
if n > 0 {
f.offset += int64(n)

m, wErr := w.Write(b[:n])
m, err := w.Write(b[:n])
written += int64(m)

if wErr != nil {
return written, wErr
if err != nil {
return written, err
}
}

Expand Down
48 changes: 23 additions & 25 deletions client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1249,55 +1249,53 @@ func TestClientReadSequential(t *testing.T) {
}
}

// this writer requires maxPacket = 2 and returns a short write error for the second write call
type lastChunkErrSequentialWriter struct {
expected int
written int
writtenReturn int
writeCounter int
}

func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) {
chunkSize := len(b)
w.written += chunkSize
if w.written == w.expected {
return w.writtenReturn, errors.New("test error")
if len(b) != 2 {
return 0, errors.New("this writer require maxPacket = 2, pleae set MaxPacketChecked(2)")
}
w.writeCounter++
switch w.writeCounter {
case 1:
return len(b), nil
default:
return 1, io.ErrShortWrite
}
return chunkSize, nil
}

func TestClientWriteSequential_WriterErr(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NODELAY)
func TestClientWriteSequentialWriterErr(t *testing.T) {
client, cmd := testClient(t, READONLY, NODELAY, MaxPacketChecked(2))
defer cmd.Wait()
defer sftp.Close()
defer client.Close()

d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr")
require.NoError(t, err)

defer os.RemoveAll(d)

var (
content = []byte("hello world")
shortWrite = 2
)
w := lastChunkErrSequentialWriter{
expected: len(content),
writtenReturn: shortWrite,
}
w := &lastChunkErrSequentialWriter{}

f, err := ioutil.TempFile(d, "write-sequential-writeerr-test")
require.NoError(t, err)
fname := f.Name()
n, err := f.Write(content)
_, err = f.Write([]byte("hello world"))
require.NoError(t, err)
require.Equal(t, n, len(content))
require.NoError(t, f.Close())

sftpFile, err := sftp.Open(fname)
sftpFile, err := client.Open(fname)
require.NoError(t, err)
defer sftpFile.Close()

gotWritten, gotErr := sftpFile.writeToSequential(&w)
require.NotErrorIs(t, io.EOF, gotErr)
require.Equal(t, int64(shortWrite), gotWritten)
written, err := sftpFile.writeToSequential(w)
assert.NotNil(t, err)
if written != 3 {
t.Errorf("sftpFile.Write() = %d, but expected 3", written)
}
assert.Equal(t, 2, w.writeCounter)
}

func TestClientReadDir(t *testing.T) {
Expand Down

0 comments on commit ae00b32

Please sign in to comment.