diff --git a/request-server.go b/request-server.go
index cb357e3b..e9974d52 100644
--- a/request-server.go
+++ b/request-server.go
@@ -106,37 +106,21 @@ func (rs *RequestServer) closeRequest(handle string) error {
 // Close the read/write/closer to trigger exiting the main server loop
 func (rs *RequestServer) Close() error { return rs.conn.Close() }
 
-// Serve requests for user session
-func (rs *RequestServer) Serve() error {
-	defer func() {
-		if rs.pktMgr.alloc != nil {
-			rs.pktMgr.alloc.Free()
-		}
-	}()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	var wg sync.WaitGroup
-	runWorker := func(ch chan orderedRequest) {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			if err := rs.packetWorker(ctx, ch); err != nil {
-				rs.conn.Close() // shuts down recvPacket
-			}
-		}()
-	}
-	pktChan := rs.pktMgr.workerChan(runWorker)
+func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
+	defer close(pktChan) // shuts down sftpServerWorkers
 
 	var err error
 	var pkt requestPacket
 	var pktType uint8
 	var pktBytes []byte
+
 	for {
 		pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
 		if err != nil {
 			// we don't care about releasing allocated pages here, the server will quit and the allocator freed
-			break
+			return err
 		}
+
 		pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
 		if err != nil {
 			switch errors.Cause(err) {
@@ -145,33 +129,47 @@ func (rs *RequestServer) Serve() error {
 			default:
 				debug("makePacket err: %v", err)
 				rs.conn.Close() // shuts down recvPacket
-				break
+				return err
 			}
 		}
 
 		pktChan <- rs.pktMgr.newOrderedRequest(pkt)
 	}
+}
 
-	close(pktChan) // shuts down sftpServerWorkers
-	wg.Wait()      // wait for all workers to exit
+// Serve requests for user session
+func (rs *RequestServer) Serve() error {
+	defer func() {
+		if rs.pktMgr.alloc != nil {
+			rs.pktMgr.alloc.Free()
+		}
+	}()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	var wg sync.WaitGroup
+	runWorker := func(ch chan orderedRequest) {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := rs.packetWorker(ctx, ch); err != nil {
+				rs.conn.Close() // shuts down recvPacket
+			}
+		}()
+	}
+	pktChan := rs.pktMgr.workerChan(runWorker)
+
+	err := rs.serveLoop(pktChan)
+
+	wg.Wait() // wait for all workers to exit
+
+	rs.openRequestLock.Lock()
+	defer rs.openRequestLock.Unlock()
 
 	// make sure all open requests are properly closed
 	// (eg. possible on dropped connections, client crashes, etc.)
 	for handle, req := range rs.openRequests {
-		if err != nil {
-			req.state.RLock()
-			writer := req.state.writerAt
-			reader := req.state.readerAt
-			req.state.RUnlock()
-			if t, ok := writer.(TransferError); ok {
-				debug("notify error: %v to writer: %v\n", err, writer)
-				t.TransferError(err)
-			}
-			if t, ok := reader.(TransferError); ok {
-				debug("notify error: %v to reader: %v\n", err, reader)
-				t.TransferError(err)
-			}
-		}
+		req.transferError(err)
+
 		delete(rs.openRequests, handle)
 		req.close()
 	}
diff --git a/request.go b/request.go
index 772628bf..df2d0379 100644
--- a/request.go
+++ b/request.go
@@ -138,19 +138,51 @@ func (r *Request) close() error {
 			r.cancelCtx()
 		}
 	}()
+
 	r.state.RLock()
+	wr := r.state.writerAt
 	rd := r.state.readerAt
 	r.state.RUnlock()
+
+	var err error
+
+	// Close errors on a Writer are far more likely to be the important one.
+	// As they can be information that there was a loss of data.
+	if c, ok := wr.(io.Closer); ok {
+		if err2 := c.Close(); err == nil {
+			// update error if it is still nil
+			err = err2
+		}
+	}
+
 	if c, ok := rd.(io.Closer); ok {
-		return c.Close()
+		if err2 := c.Close(); err == nil {
+			// update error if it is still nil
+			err = err2
+		}
+	}
+
+	return err
+}
+
+// Close reader/writer if possible
+func (r *Request) transferError(err error) {
+	if err == nil {
+		return
 	}
+
 	r.state.RLock()
-	wt := r.state.writerAt
+	wr := r.state.writerAt
+	rd := r.state.readerAt
 	r.state.RUnlock()
-	if c, ok := wt.(io.Closer); ok {
-		return c.Close()
+
+	if t, ok := wr.(TransferError); ok {
+		t.TransferError(err)
+	}
+
+	if t, ok := rd.(TransferError); ok {
+		t.TransferError(err)
 	}
-	return nil
 }
 
 // called from worker to handle packet/request