diff --git a/.travis.yml b/.travis.yml index 51dead07..e9490286 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,8 @@ go_import_path: github.com/pkg/sftp # current and previous stable releases, plus tip # remember to exclude previous and tip for macs below go: - - 1.12.x - 1.13.x + - 1.14.x - tip os: @@ -15,7 +15,7 @@ os: matrix: exclude: - os: osx - go: 1.12.x + go: 1.13.x - os: osx go: tip @@ -35,6 +35,12 @@ script: - go test -integration -v ./... - go test -testserver -v ./... - go test -integration -testserver -v ./... + - go test -integration -allocator -v ./... + - go test -testserver -allocator -v ./... + - go test -integration -testserver -allocator -v ./... - go test -race -integration -v ./... - go test -race -testserver -v ./... - go test -race -integration -testserver -v ./... + - go test -race -integration -allocator -v ./... + - go test -race -testserver -allocator -v ./... + - go test -race -integration -allocator -testserver -v ./... diff --git a/Makefile b/Makefile index 781fe1f5..0afad584 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,17 @@ integration: - go test -integration -v - go test -testserver -v - go test -integration -testserver -v + go test -integration -v ./... + go test -testserver -v ./... + go test -integration -testserver -v ./... + go test -integration -allocator -v ./... + go test -testserver -allocator -v ./... + go test -integration -testserver -allocator -v ./... integration_w_race: - go test -race -integration -v - go test -race -testserver -v - go test -race -integration -testserver -v + go test -race -integration -v ./... + go test -race -testserver -v ./... + go test -race -integration -testserver -v ./... + go test -race -integration -allocator -v ./... + go test -race -testserver -allocator -v ./... + go test -race -integration -allocator -testserver -v ./... diff --git a/allocator.go b/allocator.go new file mode 100644 index 00000000..3e67e543 --- /dev/null +++ b/allocator.go @@ -0,0 +1,96 @@ +package sftp + +import ( + "sync" +) + +type allocator struct { + sync.Mutex + available [][]byte + // map key is the request order + used map[uint32][][]byte +} + +func newAllocator() *allocator { + return &allocator{ + // micro optimization: initialize available pages with an initial capacity + available: make([][]byte, 0, SftpServerWorkerCount*2), + used: make(map[uint32][][]byte), + } +} + +// GetPage returns a previously allocated and unused []byte or create a new one. +// The slice have a fixed size = maxMsgLength, this value is suitable for both +// receiving new packets and reading the files to serve +func (a *allocator) GetPage(requestOrderID uint32) []byte { + a.Lock() + defer a.Unlock() + + var result []byte + + // get an available page and remove it from the available ones. + if len(a.available) > 0 { + truncLength := len(a.available) - 1 + result = a.available[truncLength] + + a.available[truncLength] = nil // clear out the internal pointer + a.available = a.available[:truncLength] // truncate the slice + } + + // no preallocated slice found, just allocate a new one + if result == nil { + result = make([]byte, maxMsgLength) + } + + // put result in used pages + a.used[requestOrderID] = append(a.used[requestOrderID], result) + + return result +} + +// ReleasePages marks unused all pages in use for the given requestID +func (a *allocator) ReleasePages(requestOrderID uint32) { + a.Lock() + defer a.Unlock() + + if used := a.used[requestOrderID]; len(used) > 0 { + a.available = append(a.available, used...) + } + delete(a.used, requestOrderID) +} + +// Free removes all the used and available pages. +// Call this method when the allocator is not needed anymore +func (a *allocator) Free() { + a.Lock() + defer a.Unlock() + + a.available = nil + a.used = make(map[uint32][][]byte) +} + +func (a *allocator) countUsedPages() int { + a.Lock() + defer a.Unlock() + + num := 0 + for _, p := range a.used { + num += len(p) + } + return num +} + +func (a *allocator) countAvailablePages() int { + a.Lock() + defer a.Unlock() + + return len(a.available) +} + +func (a *allocator) isRequestOrderIDUsed(requestOrderID uint32) bool { + a.Lock() + defer a.Unlock() + + _, ok := a.used[requestOrderID] + return ok +} diff --git a/allocator_test.go b/allocator_test.go new file mode 100644 index 00000000..74f4da1a --- /dev/null +++ b/allocator_test.go @@ -0,0 +1,135 @@ +package sftp + +import ( + "strconv" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllocator(t *testing.T) { + allocator := newAllocator() + // get a page for request order id 1 + page := allocator.GetPage(1) + page[1] = uint8(1) + assert.Equal(t, maxMsgLength, len(page)) + assert.Equal(t, 1, allocator.countUsedPages()) + // get another page for request order id 1, we now have 2 used pages + page = allocator.GetPage(1) + page[0] = uint8(2) + assert.Equal(t, 2, allocator.countUsedPages()) + // get another page for request order id 1, we now have 3 used pages + page = allocator.GetPage(1) + page[2] = uint8(3) + assert.Equal(t, 3, allocator.countUsedPages()) + // release the page for request order id 1, we now have 3 available pages + allocator.ReleasePages(1) + assert.NotContains(t, allocator.used, 1) + assert.Equal(t, 3, allocator.countAvailablePages()) + // get a page for request order id 2 + // we get the latest released page, let's verify that by checking the previously written values + // so we are sure we are reusing a previously allocated page + page = allocator.GetPage(2) + assert.Equal(t, uint8(3), page[2]) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 1, allocator.countUsedPages()) + page = allocator.GetPage(2) + assert.Equal(t, uint8(2), page[0]) + assert.Equal(t, 1, allocator.countAvailablePages()) + assert.Equal(t, 2, allocator.countUsedPages()) + page = allocator.GetPage(2) + assert.Equal(t, uint8(1), page[1]) + // we now have 3 used pages for request order id 2 and no available pages + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // release some request order id with no allocated pages, should have no effect + allocator.ReleasePages(1) + allocator.ReleasePages(3) + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // now get some pages for another request order id + allocator.GetPage(3) + // we now must have 3 used pages for request order id 2 and 1 used page for request order id 3 + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 4, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.True(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // get another page for request order id 3 + allocator.GetPage(3) + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 5, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.True(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // now release the pages for request order id 3 + allocator.ReleasePages(3) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + assert.False(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be not used") + // again check we are reusing previously allocated pages. + // We have written nothing to the 2 last requested page so release them and get the third one + allocator.ReleasePages(2) + assert.Equal(t, 5, allocator.countAvailablePages()) + assert.Equal(t, 0, allocator.countUsedPages()) + assert.False(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be not used") + allocator.GetPage(4) + allocator.GetPage(4) + page = allocator.GetPage(4) + assert.Equal(t, uint8(3), page[2]) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(4), "page with request order id 4 must be used") + // free the allocator + allocator.Free() + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 0, allocator.countUsedPages()) +} + +func BenchmarkAllocatorSerial(b *testing.B) { + allocator := newAllocator() + for i := 0; i < b.N; i++ { + benchAllocator(allocator, uint32(i)) + } +} + +func BenchmarkAllocatorParallel(b *testing.B) { + var counter uint32 + allocator := newAllocator() + for i := 1; i <= 8; i *= 2 { + b.Run(strconv.Itoa(i), func(b *testing.B) { + b.SetParallelism(i) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + benchAllocator(allocator, atomic.AddUint32(&counter, 1)) + } + }) + }) + } +} + +func benchAllocator(allocator *allocator, requestOrderID uint32) { + // simulates the page requested in recvPacket + allocator.GetPage(requestOrderID) + // simulates the page requested in fileget for downloads + allocator.GetPage(requestOrderID) + // release the allocated pages + allocator.ReleasePages(requestOrderID) +} + +// useful for debug +func printAllocatorContents(allocator *allocator) { + for o, u := range allocator.used { + debug("used order id: %v, values: %+v", o, u) + } + for _, v := range allocator.available { + debug("available, values: %+v", v) + } +} diff --git a/client.go b/client.go index 0d09d2a2..f5905cf9 100644 --- a/client.go +++ b/client.go @@ -214,7 +214,7 @@ func (c *Client) nextID() uint32 { } func (c *Client) recvVersion() error { - typ, data, err := c.recvPacket() + typ, data, err := c.recvPacket(0) if err != nil { return err } diff --git a/conn.go b/conn.go index cfbc1d21..0d8de601 100644 --- a/conn.go +++ b/conn.go @@ -13,13 +13,17 @@ import ( type conn struct { io.Reader io.WriteCloser + // this is the same allocator used in packet manager + alloc *allocator sync.Mutex // used to serialise writes to sendPacket // sendPacketTest is needed to replicate packet issues in testing sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error } -func (c *conn) recvPacket() (uint8, []byte, error) { - return recvPacket(c) +// the orderID is used in server mode if the allocator is enabled. +// For the client mode just pass 0 +func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { + return recvPacket(c, c.alloc, orderID) } func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { @@ -76,7 +80,7 @@ func (c *clientConn) recv() error { c.conn.Close() }() for { - typ, data, err := c.recvPacket() + typ, data, err := c.recvPacket(0) if err != nil { return err } diff --git a/packet-manager.go b/packet-manager.go index 45d29956..c870c378 100644 --- a/packet-manager.go +++ b/packet-manager.go @@ -18,6 +18,8 @@ type packetManager struct { sender packetSender // connection object working *sync.WaitGroup packetCount uint32 + // it is not nil if the allocator is enabled + alloc *allocator } type packetSender interface { @@ -44,6 +46,14 @@ func (s *packetManager) newOrderID() uint32 { return s.packetCount } +// returns the next orderID without incrementing it. +// This is used before receiving a new packet, with the allocator enabled, to associate +// the slice allocated for the received packet with the orderID that will be used to mark +// the allocated slices for reuse once the request is served +func (s *packetManager) getNextOrderID() uint32 { + return s.packetCount + 1 +} + type orderedRequest struct { requestPacket orderid uint32 @@ -174,6 +184,10 @@ func (s *packetManager) maybeSendPackets() { if in.orderID() == out.orderID() { debug("Sending packet: %v", out.id()) s.sender.sendPacket(out.(encoding.BinaryMarshaler)) + if s.alloc != nil { + // mark for reuse the slices allocated for this request + s.alloc.ReleasePages(in.orderID()) + } // pop off heads copy(s.incoming, s.incoming[1:]) // shift left s.incoming[len(s.incoming)-1] = nil // clear last diff --git a/packet.go b/packet.go index 7f55e542..dba34d2b 100644 --- a/packet.go +++ b/packet.go @@ -139,27 +139,34 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { return nil } -func recvPacket(r io.Reader) (uint8, []byte, error) { - var b = []byte{0, 0, 0, 0} - if _, err := io.ReadFull(r, b); err != nil { +func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) { + var b []byte + if alloc != nil { + b = alloc.GetPage(orderID) + } else { + b = make([]byte, 4) + } + if _, err := io.ReadFull(r, b[:4]); err != nil { return 0, nil, err } - l, _ := unmarshalUint32(b) - if l > maxMsgLength { - debug("recv packet %d bytes too long", l) + length, _ := unmarshalUint32(b) + if length > maxMsgLength { + debug("recv packet %d bytes too long", length) return 0, nil, errLongPacket } - b = make([]byte, l) - if _, err := io.ReadFull(r, b); err != nil { - debug("recv packet %d bytes: err %v", l, err) + if alloc == nil { + b = make([]byte, length) + } + if _, err := io.ReadFull(r, b[:length]); err != nil { + debug("recv packet %d bytes: err %v", length, err) return 0, nil, err } if debugDumpRxPacketBytes { - debug("recv packet: %s %d bytes %x", fxp(b[0]), l, b[1:]) + debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length]) } else if debugDumpRxPacket { - debug("recv packet: %s %d bytes", fxp(b[0]), l) + debug("recv packet: %s %d bytes", fxp(b[0]), length) } - return b[0], b[1:], nil + return b[0], b[1:length], nil } type extensionPair struct { @@ -584,10 +591,15 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpReadPacket) getDataSlice() []byte { +func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { dataLen := clamp(p.Len, maxTxPacket) + if alloc != nil { + // GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in + // sshFxpDataPacket.MarshalBinary and sendPacket + return alloc.GetPage(orderID)[:dataLen] + } // we allocate a slice with a bigger capacity so we avoid a new allocation in sshFxpDataPacket.MarshalBinary - // and in sendPacket, we need 9 bytes in MarshalBinary and 4 bytes in sendPacket + // and in sendPacket, we need 9 bytes in MarshalBinary and 4 bytes in sendPacket. return make([]byte, dataLen, dataLen+9+4) } diff --git a/packet_test.go b/packet_test.go index 8378fb63..8b16be6e 100644 --- a/packet_test.go +++ b/packet_test.go @@ -206,7 +206,7 @@ var recvPacketTests = []struct { func TestRecvPacket(t *testing.T) { for _, tt := range recvPacketTests { r := bytes.NewReader(tt.b) - got, rest, _ := recvPacket(r) + got, rest, _ := recvPacket(r, nil, 0) if got != tt.want || !bytes.Equal(rest, tt.rest) { t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) } diff --git a/request-server.go b/request-server.go index d15f3a11..cb357e3b 100644 --- a/request-server.go +++ b/request-server.go @@ -32,21 +32,41 @@ type RequestServer struct { handleCount int } +// A RequestServerOption is a function which applies configuration to a RequestServer. +type RequestServerOption func(*RequestServer) + +// WithRSAllocator enable the allocator. +// After processing a packet we keep in memory the allocated slices +// and we reuse them for new packets. +// The allocator is experimental +func WithRSAllocator() RequestServerOption { + return func(rs *RequestServer) { + alloc := newAllocator() + rs.pktMgr.alloc = alloc + rs.conn.alloc = alloc + } +} + // NewRequestServer creates/allocates/returns new RequestServer. -// Normally there there will be one server per user-session. -func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer { +// Normally there will be one server per user-session. +func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer { svrConn := &serverConn{ conn: conn{ Reader: rwc, WriteCloser: rwc, }, } - return &RequestServer{ + rs := &RequestServer{ serverConn: svrConn, Handlers: h, pktMgr: newPktMgr(svrConn), openRequests: make(map[string]*Request), } + + for _, o := range options { + o(rs) + } + return rs } // New Open packet/Request @@ -88,6 +108,11 @@ func (rs *RequestServer) Close() error { return rs.conn.Close() } // Serve requests for user session func (rs *RequestServer) Serve() error { + defer func() { + if rs.pktMgr.alloc != nil { + rs.pktMgr.alloc.Free() + } + }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() var wg sync.WaitGroup @@ -107,11 +132,11 @@ func (rs *RequestServer) Serve() error { var pktType uint8 var pktBytes []byte for { - pktType, pktBytes, err = rs.recvPacket() + pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID()) if err != nil { + // we don't care about releasing allocated pages here, the server will quit and the allocator freed break } - pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) if err != nil { switch errors.Cause(err) { @@ -158,6 +183,7 @@ func (rs *RequestServer) packetWorker( ctx context.Context, pktChan chan orderedRequest, ) error { for pkt := range pktChan { + orderID := pkt.orderID() if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok { if epkt.SpecificPacket != nil { pkt.requestPacket = epkt.SpecificPacket @@ -188,30 +214,30 @@ func (rs *RequestServer) packetWorker( rpkt = statusFromError(pkt, syscall.EBADF) } else { request = NewRequest("Stat", request.Filepath) - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) } case *sshFxpExtendedPacketPosixRename: request := NewRequest("Rename", pkt.Oldpath) request.Target = pkt.Newpath - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) case hasHandle: handle := pkt.getHandle() request, ok := rs.getRequest(handle) if !ok { rpkt = statusFromError(pkt, syscall.EBADF) } else { - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) } case hasPath: request := requestFromPacket(ctx, pkt) - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) request.close() default: rpkt = statusFromError(pkt, ErrSSHFxOpUnsupported) } rs.pktMgr.readyPacket( - rs.pktMgr.newOrderedResponse(rpkt, pkt.orderID())) + rs.pktMgr.newOrderedResponse(rpkt, orderID)) } return nil } @@ -232,7 +258,7 @@ func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket { // Makes sure we have a clean POSIX (/) absolute path to work with func cleanPath(p string) string { p = filepath.ToSlash(p) - if !filepath.IsAbs(p) { + if !path.IsAbs(p) { p = "/" + p } return path.Clean(p) diff --git a/request-server_test.go b/request-server_test.go index 6f06bd22..703b7d92 100644 --- a/request-server_test.go +++ b/request-server_test.go @@ -46,7 +46,11 @@ func clientRequestServerPair(t *testing.T) *csPair { fd, err := l.Accept() assert.Nil(t, err) handlers := InMemHandler() - server = NewRequestServer(fd, handlers) + var options []RequestServerOption + if *testAllocator { + options = append(options, WithRSAllocator()) + } + server = NewRequestServer(fd, handlers, options...) server.Serve() }() <-ready @@ -60,6 +64,15 @@ func clientRequestServerPair(t *testing.T) *csPair { return &csPair{client, server} } +func checkRequestServerAllocator(t *testing.T, p *csPair) { + if p.svr.pktMgr.alloc == nil { + return + } + checkAllocatorBeforeServerClose(t, p.svr.pktMgr.alloc) + p.Close() + checkAllocatorAfterServerClose(t, p.svr.pktMgr.alloc) +} + // after adding logging, maybe check log to make sure packet handling // was split over more than one worker func TestRequestSplitWrite(t *testing.T) { @@ -74,6 +87,7 @@ func TestRequestSplitWrite(t *testing.T) { r := p.testHandler() f, _ := r.fetch("/foo") assert.Equal(t, contents, string(f.content)) + checkRequestServerAllocator(t, p) } func TestRequestCache(t *testing.T) { @@ -101,6 +115,7 @@ func TestRequestCache(t *testing.T) { assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled") p.svr.closeRequest(bh) assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) } func TestRequestCacheState(t *testing.T) { @@ -114,6 +129,7 @@ func TestRequestCacheState(t *testing.T) { err = p.cli.Remove("/foo") assert.Nil(t, err) assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) } func putTestFile(cli *Client, path, content string) (int, error) { @@ -136,6 +152,7 @@ func TestRequestWrite(t *testing.T) { assert.Nil(t, err) assert.False(t, f.isdir) assert.Equal(t, f.content, []byte("hello")) + checkRequestServerAllocator(t, p) } func TestRequestWriteEmpty(t *testing.T) { @@ -156,6 +173,7 @@ func TestRequestWriteEmpty(t *testing.T) { assert.Error(t, err) r.returnErr(nil) assert.Equal(t, 0, n) + checkRequestServerAllocator(t, p) } func TestRequestFilename(t *testing.T) { @@ -169,6 +187,7 @@ func TestRequestFilename(t *testing.T) { assert.Equal(t, f.Name(), "foo") _, err = r.fetch("/bar") assert.Error(t, err) + checkRequestServerAllocator(t, p) } func TestRequestJustRead(t *testing.T) { @@ -186,6 +205,7 @@ func TestRequestJustRead(t *testing.T) { } assert.Equal(t, 5, n) assert.Equal(t, "hello", string(contents[0:5])) + checkRequestServerAllocator(t, p) } func TestRequestOpenFail(t *testing.T) { @@ -194,6 +214,7 @@ func TestRequestOpenFail(t *testing.T) { rf, err := p.cli.Open("/foo") assert.Exactly(t, os.ErrNotExist, err) assert.Nil(t, rf) + checkRequestServerAllocator(t, p) } func TestRequestCreate(t *testing.T) { @@ -203,6 +224,7 @@ func TestRequestCreate(t *testing.T) { assert.Nil(t, err) err = fh.Close() assert.Nil(t, err) + checkRequestServerAllocator(t, p) } func TestRequestMkdir(t *testing.T) { @@ -214,6 +236,7 @@ func TestRequestMkdir(t *testing.T) { f, err := r.fetch("/foo") assert.Nil(t, err) assert.True(t, f.isdir) + checkRequestServerAllocator(t, p) } func TestRequestRemove(t *testing.T) { @@ -228,6 +251,7 @@ func TestRequestRemove(t *testing.T) { assert.Nil(t, err) _, err = r.fetch("/foo") assert.Equal(t, err, os.ErrNotExist) + checkRequestServerAllocator(t, p) } func TestRequestRename(t *testing.T) { @@ -256,6 +280,7 @@ func TestRequestRename(t *testing.T) { assert.Equal(t, "baz", f.Name()) _, err = r.fetch("/bar") assert.Equal(t, os.ErrNotExist, err) + checkRequestServerAllocator(t, p) } func TestRequestRenameFail(t *testing.T) { @@ -267,6 +292,7 @@ func TestRequestRenameFail(t *testing.T) { assert.Nil(t, err) err = p.cli.Rename("/foo", "/bar") assert.IsType(t, &StatusError{}, err) + checkRequestServerAllocator(t, p) } func TestRequestStat(t *testing.T) { @@ -280,6 +306,7 @@ func TestRequestStat(t *testing.T) { assert.Equal(t, fi.Mode(), os.FileMode(0644)) assert.NoError(t, testOsSys(fi.Sys())) assert.NoError(t, err) + checkRequestServerAllocator(t, p) } // NOTE: Setstat is a noop in the request server tests, but we want to test @@ -298,6 +325,7 @@ func TestRequestSetstat(t *testing.T) { assert.Equal(t, fi.Size(), int64(5)) assert.Equal(t, fi.Mode(), os.FileMode(0644)) assert.NoError(t, testOsSys(fi.Sys())) + checkRequestServerAllocator(t, p) } func TestRequestFstat(t *testing.T) { @@ -314,6 +342,7 @@ func TestRequestFstat(t *testing.T) { assert.Equal(t, fi.Mode(), os.FileMode(0644)) assert.NoError(t, testOsSys(fi.Sys())) } + checkRequestServerAllocator(t, p) } func TestRequestStatFail(t *testing.T) { @@ -322,6 +351,7 @@ func TestRequestStatFail(t *testing.T) { fi, err := p.cli.Stat("/foo") assert.Nil(t, fi) assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) } func TestRequestLink(t *testing.T) { @@ -335,6 +365,7 @@ func TestRequestLink(t *testing.T) { fi, err := r.fetch("/bar") assert.Nil(t, err) assert.True(t, int(fi.Size()) == len("hello")) + checkRequestServerAllocator(t, p) } func TestRequestLinkFail(t *testing.T) { @@ -343,6 +374,7 @@ func TestRequestLinkFail(t *testing.T) { err := p.cli.Link("/foo", "/bar") t.Log(err) assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) } func TestRequestSymlink(t *testing.T) { @@ -356,6 +388,7 @@ func TestRequestSymlink(t *testing.T) { fi, err := r.fetch("/bar") assert.Nil(t, err) assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink) + checkRequestServerAllocator(t, p) } func TestRequestSymlinkFail(t *testing.T) { @@ -363,6 +396,7 @@ func TestRequestSymlinkFail(t *testing.T) { defer p.Close() err := p.cli.Symlink("/foo", "/bar") assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) } func TestRequestReadlink(t *testing.T) { @@ -375,6 +409,7 @@ func TestRequestReadlink(t *testing.T) { rl, err := p.cli.ReadLink("/bar") assert.Nil(t, err) assert.Equal(t, "foo", rl) + checkRequestServerAllocator(t, p) } func TestRequestReaddir(t *testing.T) { @@ -398,6 +433,7 @@ func TestRequestReaddir(t *testing.T) { assert.Len(t, di, 100) names := []string{di[18].Name(), di[81].Name()} assert.Equal(t, []string{"foo_18", "foo_81"}, names) + checkRequestServerAllocator(t, p) } func TestCleanPath(t *testing.T) { diff --git a/request.go b/request.go index c81bb784..772628bf 100644 --- a/request.go +++ b/request.go @@ -154,12 +154,12 @@ func (r *Request) close() error { } // called from worker to handle packet/request -func (r *Request) call(handlers Handlers, pkt requestPacket) responsePacket { +func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { switch r.Method { case "Get": - return fileget(handlers.FileGet, r, pkt) + return fileget(handlers.FileGet, r, pkt, alloc, orderID) case "Put": - return fileput(handlers.FilePut, r, pkt) + return fileput(handlers.FilePut, r, pkt, alloc, orderID) case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove": return filecmd(handlers.FileCmd, r, pkt) case "List": @@ -206,7 +206,7 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { } // wrap FileReader handler -func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket { +func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { //fmt.Println("fileget", r) r.state.RLock() reader := r.state.readerAt @@ -215,7 +215,7 @@ func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket { return statusFromError(pkt, errors.New("unexpected read packet")) } - data, offset, _ := packetData(pkt) + data, offset, _ := packetData(pkt, alloc, orderID) n, err := reader.ReadAt(data, offset) // only return EOF erro if no data left to read if err != nil && (err != io.EOF || n == 0) { @@ -229,7 +229,7 @@ func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket { } // wrap FileWriter handler -func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket { +func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { //fmt.Println("fileput", r) r.state.RLock() writer := r.state.writerAt @@ -238,18 +238,18 @@ func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket { return statusFromError(pkt, errors.New("unexpected write packet")) } - data, offset, _ := packetData(pkt) + data, offset, _ := packetData(pkt, alloc, orderID) _, err := writer.WriteAt(data, offset) return statusFromError(pkt, err) } // file data for additional read/write packets -func packetData(p requestPacket) (data []byte, offset int64, length uint32) { +func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) { switch p := p.(type) { case *sshFxpReadPacket: length = p.Len offset = int64(p.Offset) - data = p.getDataSlice() + data = p.getDataSlice(alloc, orderID) case *sshFxpWritePacket: data = p.Data length = p.Length diff --git a/request_test.go b/request_test.go index ea14168e..9f1ed661 100644 --- a/request_test.go +++ b/request_test.go @@ -152,7 +152,7 @@ func TestRequestGet(t *testing.T) { for i, txt := range []string{"file-", "data."} { pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a", Offset: uint64(i * 5), Len: 5} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) dpkt := rpkt.(*sshFxpDataPacket) assert.Equal(t, dpkt.id(), uint32(i)) assert.Equal(t, string(dpkt.Data), txt) @@ -165,7 +165,7 @@ func TestRequestCustomError(t *testing.T) { pkt := fakePacket{myid: 1} cmdErr := errors.New("stat not supported") handlers.returnError(cmdErr) - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) assert.Equal(t, rpkt, statusFromError(rpkt, cmdErr)) } @@ -176,11 +176,11 @@ func TestRequestPut(t *testing.T) { request.state.writerAt, _ = handlers.FilePut.Filewrite(request) pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5, Data: []byte("file-")} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) checkOkStatus(t, rpkt) pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5, Data: []byte("data.")} - rpkt = request.call(handlers, pkt) + rpkt = request.call(handlers, pkt, nil, 0) checkOkStatus(t, rpkt) assert.Equal(t, "file-data.", handlers.getOutString()) } @@ -189,11 +189,11 @@ func TestRequestCmdr(t *testing.T) { handlers := newTestHandlers() request := testRequest("Mkdir") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) checkOkStatus(t, rpkt) handlers.returnError(errTest) - rpkt = request.call(handlers, pkt) + rpkt = request.call(handlers, pkt, nil, 0) assert.Equal(t, rpkt, statusFromError(rpkt, errTest)) } @@ -201,7 +201,7 @@ func TestRequestInfoStat(t *testing.T) { handlers := newTestHandlers() request := testRequest("Stat") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) spkt, ok := rpkt.(*sshFxpStatResponse) assert.True(t, ok) assert.Equal(t, spkt.info.Name(), "request_test.go") @@ -218,13 +218,13 @@ func TestRequestInfoList(t *testing.T) { assert.Equal(t, hpkt.Handle, "1") } pkt = fakePacket{myid: 2} - request.call(handlers, pkt) + request.call(handlers, pkt, nil, 0) } func TestRequestInfoReadlink(t *testing.T) { handlers := newTestHandlers() request := testRequest("Readlink") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) npkt, ok := rpkt.(*sshFxpNamePacket) if assert.True(t, ok) { assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0]) @@ -237,7 +237,7 @@ func TestOpendirHandleReuse(t *testing.T) { request := testRequest("Stat") request.handle = "1" pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) assert.IsType(t, &sshFxpStatResponse{}, rpkt) request.Method = "List" @@ -247,6 +247,6 @@ func TestOpendirHandleReuse(t *testing.T) { hpkt := rpkt.(*sshFxpHandlePacket) assert.Equal(t, hpkt.Handle, "1") } - rpkt = request.call(handlers, pkt) + rpkt = request.call(handlers, pkt, nil, 0) assert.IsType(t, &sshFxpNamePacket{}, rpkt) } diff --git a/server.go b/server.go index 7802e9b6..013350cc 100644 --- a/server.go +++ b/server.go @@ -116,6 +116,19 @@ func ReadOnly() ServerOption { } } +// WithAllocator enable the allocator. +// After processing a packet we keep in memory the allocated slices +// and we reuse them for new packets. +// The allocator is experimental +func WithAllocator() ServerOption { + return func(s *Server) error { + alloc := newAllocator() + s.pktMgr.alloc = alloc + s.conn.alloc = alloc + return nil + } +} + type rxPacket struct { pktType fxp pktBytes []byte @@ -138,9 +151,9 @@ func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { // If server is operating read-only and a write operation is requested, // return permission denied if !readonly && svr.readOnly { - svr.sendPacket(orderedResponse{ - responsePacket: statusFromError(pkt, syscall.EPERM), - orderid: pkt.orderID()}) + svr.pktMgr.readyPacket( + svr.pktMgr.newOrderedResponse(statusFromError(pkt, syscall.EPERM), pkt.orderID()), + ) continue } @@ -153,6 +166,7 @@ func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { func handlePacket(s *Server, p orderedRequest) error { var rpkt responsePacket + orderID := p.orderID() switch p := p.requestPacket.(type) { case *sshFxInitPacket: rpkt = sshFxVersionPacket{ @@ -256,7 +270,7 @@ func handlePacket(s *Server, p orderedRequest) error { f, ok := s.getHandle(p.Handle) if ok { err = nil - data := p.getDataSlice() + data := p.getDataSlice(s.pktMgr.alloc, orderID) n, _err := f.ReadAt(data, int64(p.Offset)) if _err != nil && (_err != io.EOF || n == 0) { err = _err @@ -291,13 +305,18 @@ func handlePacket(s *Server, p orderedRequest) error { return errors.Errorf("unexpected packet type %T", p) } - s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, p.orderID())) + s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, orderID)) return nil } // Serve serves SFTP connections until the streams stop or the SFTP subsystem // is stopped. func (svr *Server) Serve() error { + defer func() { + if svr.pktMgr.alloc != nil { + svr.pktMgr.alloc.Free() + } + }() var wg sync.WaitGroup runWorker := func(ch chan orderedRequest) { wg.Add(1) @@ -315,8 +334,9 @@ func (svr *Server) Serve() error { var pktType uint8 var pktBytes []byte for { - pktType, pktBytes, err = svr.recvPacket() + pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID()) if err != nil { + // we don't care about releasing allocated pages here, the server will quit and the allocator freed break } diff --git a/server_integration_test.go b/server_integration_test.go index f15a6e09..0ad87e02 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -29,6 +29,7 @@ import ( "time" "github.com/kr/fs" + "github.com/stretchr/testify/assert" "golang.org/x/crypto/ssh" ) @@ -64,6 +65,7 @@ func skipIfWindows(t testing.TB) { var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") +var testAllocator = flag.Bool("allocator", false, "perform tests using the allocator") var testSftp *string var testSftpClientBin *string @@ -468,6 +470,38 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i return stdout.String(), err } +// assert.Eventually seems to have a data rate on macOS with go 1.14 so replace it with this simpler function +func waitForCondition(t *testing.T, condition func() bool) { + start := time.Now() + tick := 10 * time.Millisecond + waitFor := 100 * time.Millisecond + for !condition() { + time.Sleep(tick) + if time.Since(start) > waitFor { + break + } + } + assert.True(t, condition()) +} + +func checkAllocatorBeforeServerClose(t *testing.T, alloc *allocator) { + if alloc != nil { + // before closing the server we are, generally, waiting for new packets in recvPacket and we have a page allocated. + // Sometime the sendPacket returns some milliseconds after the client receives the response, and so we have 2 + // allocated pages here, so wait some milliseconds. To avoid crashes we must be sure to not release the pages + // too soon. + waitForCondition(t, func() bool { return alloc.countUsedPages() <= 1 }) + } +} + +func checkAllocatorAfterServerClose(t *testing.T, alloc *allocator) { + if alloc != nil { + // wait for the server cleanup + waitForCondition(t, func() bool { return alloc.countUsedPages() == 0 }) + waitForCondition(t, func() bool { return alloc.countAvailablePages() == 0 }) + } +} + func TestServerCompareSubsystems(t *testing.T) { listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) listenerOp, hostOp, portOp := testServer(t, OpenSSHSFTP, READONLY) diff --git a/server_test.go b/server_test.go index 5191415f..6995af95 100644 --- a/server_test.go +++ b/server_test.go @@ -162,10 +162,14 @@ func runLsTestHelper(t *testing.T, result, expectedType, path string) { func clientServerPair(t *testing.T) (*Client, *Server) { cr, sw := io.Pipe() sr, cw := io.Pipe() + var options []ServerOption + if *testAllocator { + options = append(options, WithAllocator()) + } server, err := NewServer(struct { io.Reader io.WriteCloser - }{sr, sw}) + }{sr, sw}, options...) if err != nil { t.Fatal(err) } @@ -198,6 +202,15 @@ func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) { return b, nil } +func checkServerAllocator(t *testing.T, server *Server) { + if server.pktMgr.alloc == nil { + return + } + checkAllocatorBeforeServerClose(t, server.pktMgr.alloc) + server.Close() + checkAllocatorAfterServerClose(t, server.pktMgr.alloc) +} + // test that errors are sent back when we request an invalid extended packet operation // this validates the following rfc draft is followed https://tools.ietf.org/html/draft-ietf-secsh-filexfer-extensions-00 func TestInvalidExtendedPacket(t *testing.T) { @@ -222,6 +235,7 @@ func TestInvalidExtendedPacket(t *testing.T) { if statusErr.Code != sshFxOPUnsupported { t.Errorf("statusErr.Code => %d, wanted %d", statusErr.Code, sshFxOPUnsupported) } + checkServerAllocator(t, server) } // test that server handles concurrent requests correctly @@ -251,6 +265,7 @@ func TestConcurrentRequests(t *testing.T) { }() } wg.Wait() + checkServerAllocator(t, server) } // Test error conversion @@ -327,4 +342,5 @@ func TestOpenStatRace(t *testing.T) { testreply(id1, ch) testreply(id2, ch) os.Remove(tmppath) + checkServerAllocator(t, server) }