From f1e28f8a88ebe24538d2f71752c55f97262ccfa1 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Wed, 17 Mar 2021 11:03:24 +0000 Subject: [PATCH 1/7] Improve benchmarks and errors --- client.go | 8 +- client_integration_test.go | 152 ++++++----- conn.go | 8 +- example_test.go | 2 +- go.sum | 3 + packet.go | 6 +- packet_test.go | 500 ++++++++++++++++++++++++------------- server_integration_test.go | 34 ++- sftp.go | 8 +- 9 files changed, 469 insertions(+), 252 deletions(-) diff --git a/client.go b/client.go index e4e98869..0f5b865c 100644 --- a/client.go +++ b/client.go @@ -43,10 +43,10 @@ type ClientOption func(*Client) error func MaxPacketChecked(size int) ClientOption { return func(c *Client) error { if size < 1 { - return errors.Errorf("size must be greater or equal to 1") + return errors.New("size must be greater or equal to 1") } if size > 32768 { - return errors.Errorf("sizes larger than 32KB might not work with all servers") + return errors.New("sizes larger than 32KB might not work with all servers") } c.maxPacket = size return nil @@ -65,7 +65,7 @@ func MaxPacketChecked(size int) ClientOption { func MaxPacketUnchecked(size int) ClientOption { return func(c *Client) error { if size < 1 { - return errors.Errorf("size must be greater or equal to 1") + return errors.New("size must be greater or equal to 1") } c.maxPacket = size return nil @@ -90,7 +90,7 @@ func MaxPacket(size int) ClientOption { func MaxConcurrentRequestsPerFile(n int) ClientOption { return func(c *Client) error { if n < 1 { - return errors.Errorf("n must be greater or equal to 1") + return errors.New("n must be greater or equal to 1") } c.maxConcurrentRequests = n return nil diff --git a/client_integration_test.go b/client_integration_test.go index f604819f..9c153a74 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -6,7 +6,6 @@ package sftp import ( "bytes" "crypto/sha1" - "encoding" "errors" "io" "io/ioutil" @@ -1490,26 +1489,48 @@ func TestClientReadFrom(t *testing.T) { var errFakeNet = errors.New("Fake network issue") func TestClientReadFromDeadlock(t *testing.T) { - clientWriteDeadlock(t, 1, func(f *File) { - b := make([]byte, 32768*4) - content := bytes.NewReader(b) - _, err := f.ReadFrom(content) - if err != errFakeNet { - t.Fatal("Didn't recieve correct error:", err) - } - }) + for i := 0; i < 5; i++ { + clientWriteDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) + content := bytes.NewReader(b) + _, err := f.ReadFrom(content) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't recieve correct error:", err) + } + }) + } } // Write has exact same problem func TestClientWriteDeadlock(t *testing.T) { - clientWriteDeadlock(t, 1, func(f *File) { - b := make([]byte, 32768*4) + for i := 0; i < 5; i++ { + clientWriteDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) - _, err := f.Write(b) - if err != errFakeNet { - t.Fatal("Didn't recieve correct error:", err) - } - }) + _, err := f.Write(b) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't recieve correct error:", err) + } + }) + } +} + +type timeBombWriter struct { + count int + w io.WriteCloser +} + +func (w *timeBombWriter) Write(b []byte) (int, error) { + if w.count < 1 { + return 0, errFakeNet + } + + w.count-- + return w.w.Write(b) +} + +func (w *timeBombWriter) Close() error { + return w.w.Close() } // shared body for both previous tests @@ -1534,20 +1555,13 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) { } defer w.Close() - // Override sendPacket with failing version - // Replicates network error/drop part way through (after 1 good packet) - count := 0 - sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error { - count++ - if count > N { - return errFakeNet - } - return sendPacket(w, m) + // Override the clienConn Writer with a failing version + // Replicates network error/drop part way through (after N good writes) + wrap := sftp.clientConn.conn.WriteCloser + sftp.clientConn.conn.WriteCloser = &timeBombWriter{ + count: N, + w: wrap, } - sftp.clientConn.conn.sendPacketTest = sendPacketTest - defer func() { - sftp.clientConn.conn.sendPacketTest = nil - }() // this locked (before the fix) badfunc(w) @@ -1555,27 +1569,31 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) { // Read/WriteTo has this issue as well func TestClientReadDeadlock(t *testing.T) { - clientReadDeadlock(t, 1, func(f *File) { - b := make([]byte, 32768*4) + for i := 0; i < 5; i++ { + clientReadDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) - _, err := f.Read(b) - if err != errFakeNet { - t.Fatal("Didn't recieve correct error:", err) - } - }) + _, err := f.Read(b) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't recieve correct error:", err) + } + }) + } } func TestClientWriteToDeadlock(t *testing.T) { - clientReadDeadlock(t, 2, func(f *File) { - b := make([]byte, 32768*4) + for i := 0; i < 5; i++ { + clientReadDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) - buf := bytes.NewBuffer(b) + buf := bytes.NewBuffer(b) - _, err := f.WriteTo(buf) - if err != errFakeNet { - t.Fatal("Didn't recieve correct error:", err) - } - }) + _, err := f.WriteTo(buf) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't recieve correct error:", err) + } + }) + } } func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) { @@ -1611,22 +1629,14 @@ func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) { } defer r.Close() - // Override sendPacket with failing version - // Replicates network error/drop part way through (after 1 good packet) - count := 0 - sendPacketTest := func(w io.Writer, m encoding.BinaryMarshaler) error { - count++ - if count > N { - return errFakeNet - } - return sendPacket(w, m) + // Override the clienConn Writer with a failing version + // Replicates network error/drop part way through (after N good writes) + wrap := sftp.clientConn.conn.WriteCloser + sftp.clientConn.conn.WriteCloser = &timeBombWriter{ + count: N, + w: wrap, } - sftp.clientConn.conn.sendPacketTest = sendPacketTest - defer func() { - sftp.clientConn.conn.sendPacketTest = nil - }() - // this locked (before the fix) badfunc(r) } @@ -2444,6 +2454,28 @@ func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) { benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond) } +// writeToBuffer implements the relevant parts of bytes.Buffer, +// but does not release its internal buffer when Reset. +// +// Release its internal memory when Reset is good for avoiding memory leaks, +// but not great for memory benchmarks, as this fills up a lot of irrelevant allocations. +type writeToBuffer struct { + b []byte +} + +func (w *writeToBuffer) Len() int { + return len(w.b) +} + +func (w *writeToBuffer) Reset() { + w.b = w.b[:0] +} + +func (w *writeToBuffer) Write(b []byte) (int, error) { + w.b = append(w.b, b...) + return len(b), nil +} + func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { size := 10*1024*1024 + 123 // ~10MiB @@ -2466,7 +2498,9 @@ func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { b.ResetTimer() b.SetBytes(int64(size)) - buf := new(bytes.Buffer) + buf := &writeToBuffer{ + b: make([]byte, 0, size), + } for i := 0; i < b.N; i++ { buf.Reset() diff --git a/conn.go b/conn.go index 952a2be4..de08a63a 100644 --- a/conn.go +++ b/conn.go @@ -16,8 +16,6 @@ type conn struct { // this is the same allocator used in packet manager alloc *allocator sync.Mutex // used to serialise writes to sendPacket - // sendPacketTest is needed to replicate packet issues in testing - sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error } // the orderID is used in server mode if the allocator is enabled. @@ -29,9 +27,7 @@ func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { c.Lock() defer c.Unlock() - if c.sendPacketTest != nil { - return c.sendPacketTest(c, m) - } + return sendPacket(c, m) } @@ -91,7 +87,7 @@ func (c *clientConn) recv() error { // This is an unexpected occurrence. Send the error // back to all listeners so that they terminate // gracefully. - return errors.Errorf("sid not found: %v", sid) + return errors.Errorf("sid not found: %d", sid) } ch <- result{typ: typ, data: data} diff --git a/example_test.go b/example_test.go index cf859298..0093eb0c 100644 --- a/example_test.go +++ b/example_test.go @@ -131,7 +131,7 @@ func ExampleClient_Mkdir_parents() { fi, err = client.Stat(parents) if err == nil { if !fi.IsDir() { - return fmt.Errorf("File exists: %s", parents) + return fmt.Errorf("file exists: %s", parents) } } } diff --git a/go.sum b/go.sum index 413dc5fe..c756602c 100644 --- a/go.sum +++ b/go.sum @@ -15,9 +15,12 @@ golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWP golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/packet.go b/packet.go index 4a686355..faf5ffd8 100644 --- a/packet.go +++ b/packet.go @@ -133,7 +133,7 @@ func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err erro func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { header, payload, err := marshalPacket(m) if err != nil { - return errors.Errorf("binary marshaller failed: %v", err) + return errors.Wrap(err, "binary marshaller failed") } length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start @@ -146,12 +146,12 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { binary.BigEndian.PutUint32(header[:4], uint32(length)) if _, err := w.Write(header); err != nil { - return errors.Errorf("failed to send packet: %v", err) + return errors.Wrap(err, "failed to send packet") } if len(payload) > 0 { if _, err := w.Write(payload); err != nil { - return errors.Errorf("failed to send packet payload: %v", err) + return errors.Wrap(err, "failed to send packet payload") } } diff --git a/packet_test.go b/packet_test.go index 976f66fc..505773b5 100644 --- a/packet_test.go +++ b/packet_test.go @@ -3,120 +3,174 @@ package sftp import ( "bytes" "encoding" + "errors" + "io/ioutil" "os" "testing" ) -var marshalUint32Tests = []struct { - v uint32 - want []byte -}{ - {1, []byte{0, 0, 0, 1}}, - {256, []byte{0, 0, 1, 0}}, - {^uint32(0), []byte{255, 255, 255, 255}}, -} - func TestMarshalUint32(t *testing.T) { - for _, tt := range marshalUint32Tests { + var tests = []struct { + v uint32 + want []byte + }{ + {0, []byte{0, 0, 0, 0}}, + {42, []byte{0, 0, 0, 42}}, + {42 << 8, []byte{0, 0, 42, 0}}, + {42 << 16, []byte{0, 42, 0, 0}}, + {42 << 24, []byte{42, 0, 0, 0}}, + {^uint32(0), []byte{255, 255, 255, 255}}, + } + + for _, tt := range tests { got := marshalUint32(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshalUint32(%d): want %v, got %v", tt.v, tt.want, got) + t.Errorf("marshalUint32(%d) = %#v, want %#v", tt.v, got, tt.want) } } } -var marshalUint64Tests = []struct { - v uint64 - want []byte -}{ - {1, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}}, - {256, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0}}, - {^uint64(0), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, - {1 << 32, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, -} - func TestMarshalUint64(t *testing.T) { - for _, tt := range marshalUint64Tests { + var tests = []struct { + v uint64 + want []byte + }{ + {0, []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {42, []byte{0, 0, 0, 0, 0, 0, 0, 42}}, + {42 << 8, []byte{0, 0, 0, 0, 0, 0, 42, 0}}, + {42 << 16, []byte{0, 0, 0, 0, 0, 42, 0, 0}}, + {42 << 24, []byte{0, 0, 0, 0, 42, 0, 0, 0}}, + {42 << 32, []byte{0, 0, 0, 42, 0, 0, 0, 0}}, + {42 << 40, []byte{0, 0, 42, 0, 0, 0, 0, 0}}, + {42 << 48, []byte{0, 42, 0, 0, 0, 0, 0, 0}}, + {42 << 56, []byte{42, 0, 0, 0, 0, 0, 0, 0}}, + {^uint64(0), []byte{255, 255, 255, 255, 255, 255, 255, 255}}, + } + + for _, tt := range tests { got := marshalUint64(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshalUint64(%d): want %#v, got %#v", tt.v, tt.want, got) + t.Errorf("marshalUint64(%d) = %#v, want %#v", tt.v, got, tt.want) } } } -var marshalStringTests = []struct { - v string - want []byte -}{ - {"", []byte{0, 0, 0, 0}}, - {"/foo", []byte{0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f}}, -} - func TestMarshalString(t *testing.T) { - for _, tt := range marshalStringTests { + var tests = []struct { + v string + want []byte + }{ + {"", []byte{0, 0, 0, 0}}, + {"/", []byte{0x0, 0x0, 0x0, 0x01, '/'}}, + {"/foo", []byte{0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o'}}, + {"\x00bar", []byte{0x0, 0x0, 0x0, 0x4, 0, 'b', 'a', 'r'}}, + {"b\x00ar", []byte{0x0, 0x0, 0x0, 0x4, 'b', 0, 'a', 'r'}}, + {"ba\x00r", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 0, 'r'}}, + {"bar\x00", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 'r', 0}}, + } + + for _, tt := range tests { got := marshalString(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshalString(%q): want %#v, got %#v", tt.v, tt.want, got) + t.Errorf("marshalString(%q) = %#v, want %#v", tt.v, got, tt.want) } } } -var marshalTests = []struct { - v interface{} - want []byte -}{ - {uint8(1), []byte{1}}, - {byte(1), []byte{1}}, - {uint32(1), []byte{0, 0, 0, 1}}, - {uint64(1), []byte{0, 0, 0, 0, 0, 0, 0, 1}}, - {"foo", []byte{0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f}}, - {[]uint32{1, 2, 3, 4}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x4}}, -} - func TestMarshal(t *testing.T) { - for _, tt := range marshalTests { + type Struct struct { + X, Y, Z uint32 + } + + var tests = []struct { + v interface{} + want []byte + }{ + {uint8(42), []byte{42}}, + {uint32(42 << 8), []byte{0, 0, 42, 0}}, + {uint64(42 << 32), []byte{0, 0, 0, 42, 0, 0, 0, 0}}, + {"foo", []byte{0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o'}}, + {Struct{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}}, + {[]uint32{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}}, + } + + for _, tt := range tests { got := marshal(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshal(%v): want %#v, got %#v", tt.v, tt.want, got) + t.Errorf("marshal(%#v) = %#v, want %#v", tt.v, got, tt.want) } } } -var unmarshalUint32Tests = []struct { - b []byte - want uint32 - rest []byte -}{ - {[]byte{0, 0, 0, 0}, 0, nil}, - {[]byte{0, 0, 1, 0}, 256, nil}, - {[]byte{255, 0, 0, 255}, 4278190335, nil}, -} - func TestUnmarshalUint32(t *testing.T) { - for _, tt := range unmarshalUint32Tests { - got, rest := unmarshalUint32(tt.b) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("unmarshalUint32(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) - } + testBuffer := []byte{ + 0, 0, 0, 0, + 0, 0, 0, 42, + 0, 0, 42, 0, + 0, 42, 0, 0, + 42, 0, 0, 0, + 255, 0, 0, 254, } -} -var unmarshalUint64Tests = []struct { - b []byte - want uint64 - rest []byte -}{ - {[]byte{0, 0, 0, 0, 0, 0, 0, 0}, 0, nil}, - {[]byte{0, 0, 0, 0, 0, 0, 1, 0}, 256, nil}, - {[]byte{255, 0, 0, 0, 0, 0, 0, 255}, 18374686479671623935, nil}, + var wants = []uint32{ + 0, + 42, + 42 << 8, + 42 << 16, + 42 << 24, + 255<<24 | 254, + } + + var i int + for len(testBuffer) > 0 { + got, rest := unmarshalUint32(testBuffer) + + if got != wants[i] { + t.Fatalf("unmarshalUint32(%#v) = %d, want %d", testBuffer[:4], got, wants[i]) + } + + i++ + testBuffer = rest + } } func TestUnmarshalUint64(t *testing.T) { - for _, tt := range unmarshalUint64Tests { - got, rest := unmarshalUint64(tt.b) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("unmarshalUint64(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + testBuffer := []byte{ + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 42, + 0, 0, 0, 0, 0, 0, 42, 0, + 0, 0, 0, 0, 0, 42, 0, 0, + 0, 0, 0, 0, 42, 0, 0, 0, + 0, 0, 0, 42, 0, 0, 0, 0, + 0, 0, 42, 0, 0, 0, 0, 0, + 0, 42, 0, 0, 0, 0, 0, 0, + 42, 0, 0, 0, 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 254, + } + + var wants = []uint64{ + 0, + 42, + 42 << 8, + 42 << 16, + 42 << 24, + 42 << 32, + 42 << 40, + 42 << 48, + 42 << 56, + 255<<56 | 254, + } + + var i int + for len(testBuffer) > 0 { + got, rest := unmarshalUint64(testBuffer) + + if got != wants[i] { + t.Fatalf("unmarshalUint64(%#v) = %d, want %d", testBuffer[:8], got, wants[i]) } + + i++ + testBuffer = rest } } @@ -130,85 +184,193 @@ var unmarshalStringTests = []struct { } func TestUnmarshalString(t *testing.T) { - for _, tt := range unmarshalStringTests { - got, rest := unmarshalString(tt.b) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("unmarshalUint64(%v): want %q, %#v, got %q, %#v", tt.b, tt.want, tt.rest, got, rest) + testBuffer := []byte{ + 0, 0, 0, 0, + 0, 0, 0, 1, '/', + 0, 0, 0, 4, '/', 'f', 'o', 'o', + 0, 0, 0, 4, 0, 'b', 'a', 'r', + 0, 0, 0, 4, 'b', 0, 'a', 'r', + 0, 0, 0, 4, 'b', 'a', 0, 'r', + 0, 0, 0, 4, 'b', 'a', 'r', 0, + } + + var wants = []string{ + "", + "/", + "/foo", + "\x00bar", + "b\x00ar", + "ba\x00r", + "bar\x00", + } + + var i int + for len(testBuffer) > 0 { + got, rest := unmarshalString(testBuffer) + + if got != wants[i] { + t.Fatalf("unmarshalUint64(%#v...) = %q, want %q", testBuffer[:4], got, wants[i]) } + + i++ + testBuffer = rest } } -var sendPacketTests = []struct { - p encoding.BinaryMarshaler - want []byte -}{ - {&sshFxInitPacket{ - Version: 3, - Extensions: []extensionPair{ - {"posix-rename@openssh.com", "1"}, - }, - }, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, +type nopCloserBuffer struct { + bytes.Buffer +} - {&sshFxpOpenPacket{ - ID: 1, - Path: "/foo", - Pflags: flags(os.O_RDONLY), - }, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, - - {&sshFxpWritePacket{ - ID: 124, - Handle: "foo", - Offset: 13, - Length: uint32(len([]byte("bar"))), - Data: []byte("bar"), - }, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}}, - - {&sshFxpSetstatPacket{ - ID: 31, - Path: "/bar", - Flags: flags(os.O_WRONLY), - Attrs: struct { - UID uint32 - GID uint32 - }{1000, 100}, - }, []byte{0x0, 0x0, 0x0, 0x19, 0x9, 0x0, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x3, 0xe8, 0x0, 0x0, 0x0, 0x64}}, +func (*nopCloserBuffer) Close() error { + return nil } func TestSendPacket(t *testing.T) { - for _, tt := range sendPacketTests { - var w bytes.Buffer - sendPacket(&w, tt.p) - if got := w.Bytes(); !bytes.Equal(tt.want, got) { - t.Errorf("sendPacket(%v): want %#v, got %#v", tt.p, tt.want, got) + var tests = []struct { + packet encoding.BinaryMarshaler + want []byte + }{ + { + packet: &sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x26, + 0x1, + 0x0, 0x0, 0x0, 0x3, + 0x0, 0x0, 0x0, 0x18, + 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x0, 0x0, 0x0, 0x1, + '1', + }, + }, + { + packet: &sshFxpOpenPacket{ + ID: 1, + Path: "/foo", + Pflags: flags(os.O_RDONLY), + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x15, + 0x3, + 0x0, 0x0, 0x0, 0x1, + 0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o', + 0x0, 0x0, 0x0, 0x1, + 0x0, 0x0, 0x0, 0x0, + }, + }, + { + packet: &sshFxpWritePacket{ + ID: 124, + Handle: "foo", + Offset: 13, + Length: uint32(len("bar")), + Data: []byte("bar"), + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x1b, + 0x6, + 0x0, 0x0, 0x0, 0x7c, + 0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o', + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, + 0x0, 0x0, 0x0, 0x3, 'b', 'a', 'r', + }, + }, + { + packet: &sshFxpSetstatPacket{ + ID: 31, + Path: "/bar", + Flags: sshFileXferAttrUIDGID, + Attrs: struct { + UID uint32 + GID uint32 + }{ + UID: 1000, + GID: 100, + }, + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x19, + 0x9, + 0x0, 0x0, 0x0, 0x1f, + 0x0, 0x0, 0x0, 0x4, '/', 'b', 'a', 'r', + 0x0, 0x0, 0x0, 0x2, + 0x0, 0x0, 0x3, 0xe8, + 0x0, 0x0, 0x0, 0x64, + }, + }, + } + + for _, tt := range tests { + b := new(bytes.Buffer) + sendPacket(b, tt.packet) + if got := b.Bytes(); !bytes.Equal(tt.want, got) { + t.Errorf("sendPacket(%v): got %x want %x", tt.packet, tt.want, got) } } } -func sp(p encoding.BinaryMarshaler) []byte { - var w bytes.Buffer - sendPacket(&w, p) - return w.Bytes() +func sp(data encoding.BinaryMarshaler) []byte { + b := new(bytes.Buffer) + sendPacket(b, data) + return b.Bytes() } -var recvPacketTests = []struct { - b []byte - want uint8 - rest []byte -}{ - {sp(&sshFxInitPacket{ - Version: 3, - Extensions: []extensionPair{ - {"posix-rename@openssh.com", "1"}, +func TestRecvPacket(t *testing.T) { + var recvPacketTests = []struct { + b []byte + + want uint8 + body []byte + wantErr error + }{ + { + b: sp(&sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }), + want: sshFxpInit, + body: []byte{ + 0x0, 0x0, 0x0, 0x3, + 0x0, 0x0, 0x0, 0x18, + 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x0, 0x0, 0x0, 0x01, + '1', + }, }, - }), sshFxpInit, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, -} + { + b: []byte{ + 0x0, 0x0, 0x0, 0x0, + }, + wantErr: errShortPacket, + }, + } -func TestRecvPacket(t *testing.T) { for _, tt := range recvPacketTests { r := bytes.NewReader(tt.b) - got, rest, _ := recvPacket(r, nil, 0) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + + got, body, err := recvPacket(r, nil, 0) + if tt.wantErr == nil { + if err != nil { + t.Fatalf("recvPacket(%#v): unexpected error: %v", tt.b, err) + } + } else { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("recvPacket(%#v) = %v, want %v", tt.b, err, tt.wantErr) + } + } + + if got != tt.want { + t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, got, tt.want) + } + + if !bytes.Equal(body, tt.body) { + t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, body, tt.body) } } } @@ -297,49 +459,49 @@ func TestSSHFxpOpenPackethasPflags(t *testing.T) { } } -func BenchmarkMarshalInit(b *testing.B) { +func benchMarshal(b *testing.B, packet encoding.BinaryMarshaler) { for i := 0; i < b.N; i++ { - sp(&sshFxInitPacket{ - Version: 3, - Extensions: []extensionPair{ - {"posix-rename@openssh.com", "1"}, - }, - }) + sendPacket(ioutil.Discard, packet) } } +func BenchmarkMarshalInit(b *testing.B) { + benchMarshal(b, &sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }) +} + func BenchmarkMarshalOpen(b *testing.B) { - for i := 0; i < b.N; i++ { - sp(&sshFxpOpenPacket{ - ID: 1, - Path: "/home/test/some/random/path", - Pflags: flags(os.O_RDONLY), - }) - } + benchMarshal(b, &sshFxpOpenPacket{ + ID: 1, + Path: "/home/test/some/random/path", + Pflags: flags(os.O_RDONLY), + }) } func BenchmarkMarshalWriteWorstCase(b *testing.B) { data := make([]byte, 32*1024) - for i := 0; i < b.N; i++ { - sp(&sshFxpWritePacket{ - ID: 1, - Handle: "someopaquehandle", - Offset: 0, - Length: uint32(len(data)), - Data: data, - }) - } + + benchMarshal(b, &sshFxpWritePacket{ + ID: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) } func BenchmarkMarshalWrite1k(b *testing.B) { - data := make([]byte, 1024) - for i := 0; i < b.N; i++ { - sp(&sshFxpWritePacket{ - ID: 1, - Handle: "someopaquehandle", - Offset: 0, - Length: uint32(len(data)), - Data: data, - }) - } + data := make([]byte, 1025) + + benchMarshal(b, &sshFxpWritePacket{ + ID: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) } diff --git a/server_integration_test.go b/server_integration_test.go index c4fb3b99..92e28570 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -408,16 +408,16 @@ func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, s func makeDummyKey() (string, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader) if err != nil { - return "", fmt.Errorf("cannot generate key: %v", err) + return "", fmt.Errorf("cannot generate key: %w", err) } der, err := x509.MarshalECPrivateKey(priv) if err != nil { - return "", fmt.Errorf("cannot marshal key: %v", err) + return "", fmt.Errorf("cannot marshal key: %w", err) } block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der} f, err := ioutil.TempFile("", "sftp-test-key-") if err != nil { - return "", fmt.Errorf("cannot create temp file: %v", err) + return "", fmt.Errorf("cannot create temp file: %w", err) } defer func() { if f != nil { @@ -426,16 +426,34 @@ func makeDummyKey() (string, error) { } }() if err := pem.Encode(f, block); err != nil { - return "", fmt.Errorf("cannot write key: %v", err) + return "", fmt.Errorf("cannot write key: %w", err) } if err := f.Close(); err != nil { - return "", fmt.Errorf("error closing key file: %v", err) + return "", fmt.Errorf("error closing key file: %w", err) } path := f.Name() f = nil return path, nil } +type execError struct { + path string + stderr string + err error +} + +func (e *execError) Error() string { + return fmt.Sprintf("%s: %v: %s", e.path, e.err, e.stderr) +} + +func (e *execError) Unwrap() error { + return e.err +} + +func (e *execError) Cause() error { + return e.err +} + func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) { // if sftp client binary is unavailable, skip test if _, err := os.Stat(*testSftpClientBin); err != nil { @@ -471,7 +489,11 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i } err = cmd.Wait() if err != nil { - err = fmt.Errorf("%v: %s", err, stderr.String()) + err = &execError{ + path: cmd.Path, + stderr: stderr.String(), + err: err, + } } return stdout.String(), err } diff --git a/sftp.go b/sftp.go index 912dff1d..13320832 100644 --- a/sftp.go +++ b/sftp.go @@ -200,15 +200,15 @@ func unimplementedPacketErr(u uint8) error { type unexpectedIDErr struct{ want, got uint32 } func (u *unexpectedIDErr) Error() string { - return fmt.Sprintf("sftp: unexpected id: want %v, got %v", u.want, u.got) + return fmt.Sprintf("sftp: unexpected id: want %d, got %d", u.want, u.got) } func unimplementedSeekWhence(whence int) error { - return errors.Errorf("sftp: unimplemented seek whence %v", whence) + return errors.Errorf("sftp: unimplemented seek whence %d", whence) } func unexpectedCount(want, got uint32) error { - return errors.Errorf("sftp: unexpected count: want %v, got %v", want, got) + return errors.Errorf("sftp: unexpected count: want %d, got %d", want, got) } type unexpectedVersionErr struct{ want, got uint32 } @@ -239,7 +239,7 @@ func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) return supportedExtension, nil } } - return sshExtensionPair{}, fmt.Errorf("Unsupported extension: %v", extensionName) + return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName) } // SetSFTPExtensions allows to customize the supported server extensions. From 39e1161d126076000870b4715872afd5abece710 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Wed, 17 Mar 2021 12:05:00 +0000 Subject: [PATCH 2/7] address my own code review comments --- client_integration_test.go | 10 +++++----- packet_test.go | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client_integration_test.go b/client_integration_test.go index 9c153a74..222b734a 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1569,7 +1569,7 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) { // Read/WriteTo has this issue as well func TestClientReadDeadlock(t *testing.T) { - for i := 0; i < 5; i++ { + for i := 0; i < 3; i++ { clientReadDeadlock(t, i, func(f *File) { b := make([]byte, 32768*4) @@ -1582,7 +1582,7 @@ func TestClientReadDeadlock(t *testing.T) { } func TestClientWriteToDeadlock(t *testing.T) { - for i := 0; i < 5; i++ { + for i := 0; i < 3; i++ { clientReadDeadlock(t, i, func(f *File) { b := make([]byte, 32768*4) @@ -2495,13 +2495,13 @@ func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { f.Write(data) f.Close() - b.ResetTimer() - b.SetBytes(int64(size)) - buf := &writeToBuffer{ b: make([]byte, 0, size), } + b.ResetTimer() + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { buf.Reset() diff --git a/packet_test.go b/packet_test.go index 505773b5..c7deb5ab 100644 --- a/packet_test.go +++ b/packet_test.go @@ -217,14 +217,6 @@ func TestUnmarshalString(t *testing.T) { } } -type nopCloserBuffer struct { - bytes.Buffer -} - -func (*nopCloserBuffer) Close() error { - return nil -} - func TestSendPacket(t *testing.T) { var tests = []struct { packet encoding.BinaryMarshaler @@ -349,6 +341,12 @@ func TestRecvPacket(t *testing.T) { }, wantErr: errShortPacket, }, + { + b: []byte{ + 0xff, 0xff, 0xff, 0xff, + }, + wantErr: errLongPacket, + }, } for _, tt := range recvPacketTests { @@ -460,6 +458,8 @@ func TestSSHFxpOpenPackethasPflags(t *testing.T) { } func benchMarshal(b *testing.B, packet encoding.BinaryMarshaler) { + b.ResetTimer() + for i := 0; i < b.N; i++ { sendPacket(ioutil.Discard, packet) } From b22b9e472e3d6191a10bdba94a23c8000471bbbf Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Wed, 17 Mar 2021 13:18:17 +0000 Subject: [PATCH 3/7] remove writeToBuffer, the bytes.Buffer.Grow I saw in the memprofile was elsewhere --- client_integration_test.go | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/client_integration_test.go b/client_integration_test.go index 222b734a..5ee748f6 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -2454,28 +2454,6 @@ func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) { benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond) } -// writeToBuffer implements the relevant parts of bytes.Buffer, -// but does not release its internal buffer when Reset. -// -// Release its internal memory when Reset is good for avoiding memory leaks, -// but not great for memory benchmarks, as this fills up a lot of irrelevant allocations. -type writeToBuffer struct { - b []byte -} - -func (w *writeToBuffer) Len() int { - return len(w.b) -} - -func (w *writeToBuffer) Reset() { - w.b = w.b[:0] -} - -func (w *writeToBuffer) Write(b []byte) (int, error) { - w.b = append(w.b, b...) - return len(b), nil -} - func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { size := 10*1024*1024 + 123 // ~10MiB @@ -2495,9 +2473,7 @@ func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { f.Write(data) f.Close() - buf := &writeToBuffer{ - b: make([]byte, 0, size), - } + buf := new(bytes.Buffer) b.ResetTimer() b.SetBytes(int64(size)) From 325cdac782790fcecaf6741775659f67e9baeec0 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Wed, 17 Mar 2021 14:57:00 +0000 Subject: [PATCH 4/7] Add convenient benchmark Makefile rule --- Makefile | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Makefile b/Makefile index 0afad584..8bbaab19 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +.PHONY: integration integration_w_race benchmark + integration: go test -integration -v ./... go test -testserver -v ./... @@ -14,4 +16,9 @@ integration_w_race: go test -race -testserver -allocator -v ./... go test -race -integration -allocator -testserver -v ./... +COUNT ?= 1 +BENCHMARK_PATTERN ?= "." +benchmark: + go test -integration -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -memprofile memprofile.out -count=$(COUNT) + go tool pprof -svg -output=memprofile.svg memprofile.out From 32f98f3047d1ea14b206dbe684143dd96dc7dbf3 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Wed, 17 Mar 2021 20:17:15 +0000 Subject: [PATCH 5/7] split benchmark and benchmark_w_memprofile, include memprofile files to gitignore --- .gitignore | 3 +++ Makefile | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e1ec837c..caf2dca2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ server_standalone/server_standalone examples/*/id_rsa examples/*/id_rsa.pub + +memprofile.out +memprofile.svg diff --git a/Makefile b/Makefile index 8bbaab19..4d3a0079 100644 --- a/Makefile +++ b/Makefile @@ -20,5 +20,8 @@ COUNT ?= 1 BENCHMARK_PATTERN ?= "." benchmark: - go test -integration -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -memprofile memprofile.out -count=$(COUNT) + go test -integration -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) + +benchmark_w_memprofile: + go test -integration -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) -memprofile memprofile.out go tool pprof -svg -output=memprofile.svg memprofile.out From 91163e446307be402988ac460b3915757405f0ad Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Wed, 17 Mar 2021 20:17:55 +0000 Subject: [PATCH 6/7] errors.Errorf over fmt.Errorf --- sftp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sftp.go b/sftp.go index 13320832..7c7250f4 100644 --- a/sftp.go +++ b/sftp.go @@ -239,7 +239,7 @@ func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) return supportedExtension, nil } } - return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName) + return sshExtensionPair{}, errors.Errorf("unsupported extension: %s", extensionName) } // SetSFTPExtensions allows to customize the supported server extensions. From addaabd30b107e4667f5d7ef643e006b186c2848 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Wed, 17 Mar 2021 20:18:50 +0000 Subject: [PATCH 7/7] give the bytes.Buffer a preallocated slice to use for less variance --- client_integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client_integration_test.go b/client_integration_test.go index 5ee748f6..26d58a92 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -2473,7 +2473,7 @@ func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { f.Write(data) f.Close() - buf := new(bytes.Buffer) + buf := bytes.NewBuffer(make([]byte, 0, size)) b.ResetTimer() b.SetBytes(int64(size))