diff --git a/request.go b/request.go index 6f02213e..d842787a 100644 --- a/request.go +++ b/request.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "os" - "path" - "path/filepath" "strings" "sync" "syscall" @@ -16,6 +14,113 @@ import ( // MaxFilelist is the max number of files to return in a readdir batch. var MaxFilelist int64 = 100 +// state encapsulates the reader/writer/readdir from handlers. +type state struct { + mu sync.RWMutex + + writerAt io.WriterAt + readerAt io.ReaderAt + writerAtReaderAt WriterAtReaderAt + listerAt ListerAt + lsoffset int64 +} + +// copy returns a shallow copy the state. +// This is broken out to specific fields, +// because we have to copy around the mutex in state. +func (s *state) copy() state { + s.mu.RLock() + defer s.mu.RUnlock() + + return state{ + writerAt: s.writerAt, + readerAt: s.readerAt, + writerAtReaderAt: s.writerAtReaderAt, + listerAt: s.listerAt, + lsoffset: s.lsoffset, + } +} + +func (s *state) setReaderAt(rd io.ReaderAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.readerAt = rd +} + +func (s *state) getReaderAt() io.ReaderAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.readerAt +} + +func (s *state) setWriterAt(rd io.WriterAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.writerAt = rd +} + +func (s *state) getWriterAt() io.WriterAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.writerAt +} + +func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.writerAtReaderAt = rw +} + +func (s *state) getWriterAtReaderAt() WriterAtReaderAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.writerAtReaderAt +} + +func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.readerAt, s.writerAt, s.writerAtReaderAt +} + +// Returns current offset for file list +func (s *state) lsNext() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.lsoffset +} + +// Increases next offset +func (s *state) lsInc(offset int64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.lsoffset += offset +} + +// manage file read/write state +func (s *state) setListerAt(la ListerAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.listerAt = la +} + +func (s *state) getListerAt() ListerAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.listerAt +} + // Request contains the data and state for the incoming service request. type Request struct { // Get, Put, Setstat, Stat, Rename, Remove @@ -26,20 +131,40 @@ type Request struct { Attrs []byte // convert to sub-struct Target string // for renames and sym-links handle string + // reader/writer/readdir from handlers - state state + state + // context lasts duration of request ctx context.Context cancelCtx context.CancelFunc } -type state struct { - *sync.RWMutex - writerAt io.WriterAt - readerAt io.ReaderAt - writerReaderAt WriterAtReaderAt - listerAt ListerAt - lsoffset int64 +// NewRequest creates a new Request object. +func NewRequest(method, path string) *Request { + return &Request{ + Method: method, + Filepath: cleanPath(path), + } +} + +// copy returns a shallow copy of existing request. +// This is broken out to specific fields, +// because we have to copy around the mutex in state. +func (r *Request) copy() *Request { + return &Request{ + Method: r.Method, + Filepath: r.Filepath, + Flags: r.Flags, + Attrs: r.Attrs, + Target: r.Target, + handle: r.handle, + + state: r.state.copy(), + + ctx: r.ctx, + cancelCtx: r.cancelCtx, + } } // New Request initialized based on packet data @@ -66,21 +191,6 @@ func requestFromPacket(ctx context.Context, pkt hasPath) *Request { return request } -// NewRequest creates a new Request object. -func NewRequest(method, path string) *Request { - return &Request{Method: method, Filepath: cleanPath(path), - state: state{RWMutex: new(sync.RWMutex)}} -} - -// shallow copy of existing request -func (r *Request) copy() *Request { - r.state.Lock() - defer r.state.Unlock() - r2 := new(Request) - *r2 = *r - return r2 -} - // Context returns the request's context. To change the context, // use WithContext. // @@ -108,33 +218,6 @@ func (r *Request) WithContext(ctx context.Context) *Request { return r2 } -// Returns current offset for file list -func (r *Request) lsNext() int64 { - r.state.RLock() - defer r.state.RUnlock() - return r.state.lsoffset -} - -// Increases next offset -func (r *Request) lsInc(offset int64) { - r.state.Lock() - defer r.state.Unlock() - r.state.lsoffset = r.state.lsoffset + offset -} - -// manage file read/write state -func (r *Request) setListerState(la ListerAt) { - r.state.Lock() - defer r.state.Unlock() - r.state.listerAt = la -} - -func (r *Request) getLister() ListerAt { - r.state.RLock() - defer r.state.RUnlock() - return r.state.listerAt -} - // Close reader/writer if possible func (r *Request) close() error { defer func() { @@ -143,11 +226,7 @@ func (r *Request) close() error { } }() - r.state.RLock() - wr := r.state.writerAt - rd := r.state.readerAt - rw := r.state.writerReaderAt - r.state.RUnlock() + rd, wr, rw := r.getAllReaderWriters() var err error @@ -164,7 +243,8 @@ func (r *Request) close() error { if err2 := c.Close(); err == nil { // update error if it is still nil err = err2 - r.state.writerReaderAt = nil + + r.setWriterAtReaderAt(nil) } } @@ -184,11 +264,7 @@ func (r *Request) transferError(err error) { return } - r.state.RLock() - wr := r.state.writerAt - rd := r.state.readerAt - rw := r.state.writerReaderAt - r.state.RUnlock() + rd, wr, rw := r.getAllReaderWriters() if t, ok := wr.(TransferError); ok { t.TransferError(err) @@ -219,8 +295,7 @@ func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, o case "Stat", "Lstat", "Readlink": return filestat(handlers.FileList, r, pkt) default: - return statusFromError(pkt.id(), - fmt.Errorf("unexpected method: %s", r.Method)) + return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method)) } } @@ -239,8 +314,13 @@ func (r *Request) open(h Handlers, pkt requestPacket) responsePacket { if err != nil { return statusFromError(id, err) } - r.state.writerReaderAt = rw - return &sshFxpHandlePacket{ID: id, Handle: r.handle} + + r.setWriterAtReaderAt(rw) + + return &sshFxpHandlePacket{ + ID: id, + Handle: r.handle, + } } } @@ -249,18 +329,26 @@ func (r *Request) open(h Handlers, pkt requestPacket) responsePacket { if err != nil { return statusFromError(id, err) } - r.state.writerAt = wr + + r.setWriterAt(wr) + case flags.Read: r.Method = "Get" rd, err := h.FileGet.Fileread(r) if err != nil { return statusFromError(id, err) } - r.state.readerAt = rd + + r.setReaderAt(rd) + default: return statusFromError(id, errors.New("bad file flags")) } - return &sshFxpHandlePacket{ID: id, Handle: r.handle} + + return &sshFxpHandlePacket{ + ID: id, + Handle: r.handle, + } } func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { @@ -269,25 +357,30 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { if err != nil { return statusFromError(pkt.id(), wrapPathError(r.Filepath, err)) } - r.state.listerAt = la - return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle} + + r.setListerAt(la) + + return &sshFxpHandlePacket{ + ID: pkt.id(), + Handle: r.handle, + } } // wrap FileReader handler func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { - r.state.RLock() - reader := r.state.readerAt - r.state.RUnlock() - if reader == nil { + rd := r.getReaderAt() + if rd == nil { return statusFromError(pkt.id(), errors.New("unexpected read packet")) } data, offset, _ := packetData(pkt, alloc, orderID) - n, err := reader.ReadAt(data, offset) + + n, err := rd.ReadAt(data, offset) // only return EOF error if no data left to read if err != nil && (err != io.EOF || n == 0) { return statusFromError(pkt.id(), err) } + return &sshFxpDataPacket{ ID: pkt.id(), Length: uint32(n), @@ -297,43 +390,46 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde // wrap FileWriter handler func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { - r.state.RLock() - writer := r.state.writerAt - r.state.RUnlock() - if writer == nil { + wr := r.getWriterAt() + if wr == nil { return statusFromError(pkt.id(), errors.New("unexpected write packet")) } data, offset, _ := packetData(pkt, alloc, orderID) - _, err := writer.WriteAt(data, offset) + + _, err := wr.WriteAt(data, offset) return statusFromError(pkt.id(), err) } // wrap OpenFileWriter handler func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { - r.state.RLock() - writerReader := r.state.writerReaderAt - r.state.RUnlock() - if writerReader == nil { + rw := r.getWriterAtReaderAt() + if rw == nil { return statusFromError(pkt.id(), errors.New("unexpected write and read packet")) } + switch p := pkt.(type) { case *sshFxpReadPacket: data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset) - n, err := writerReader.ReadAt(data, offset) + + n, err := rw.ReadAt(data, offset) // only return EOF error if no data left to read if err != nil && (err != io.EOF || n == 0) { return statusFromError(pkt.id(), err) } + return &sshFxpDataPacket{ ID: pkt.id(), Length: uint32(n), Data: data[:n], } + case *sshFxpWritePacket: data, offset := p.Data, int64(p.Offset) - _, err := writerReader.WriteAt(data, offset) + + _, err := rw.WriteAt(data, offset) return statusFromError(pkt.id(), err) + default: return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write")) } @@ -358,7 +454,8 @@ func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket { r.Attrs = p.Attrs.([]byte) } - if r.Method == "PosixRename" { + switch r.Method { + case "PosixRename": if posixRenamer, ok := h.(PosixRenameFileCmder); ok { err := posixRenamer.PosixRename(r) return statusFromError(pkt.id(), err) @@ -368,9 +465,8 @@ func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket { r.Method = "Rename" err := h.Filecmd(r) return statusFromError(pkt.id(), err) - } - if r.Method == "StatVFS" { + case "StatVFS": if statVFSCmdr, ok := h.(StatVFSFileCmder); ok { stat, err := statVFSCmdr.StatVFS(r) if err != nil { @@ -389,8 +485,7 @@ func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket { // wrap FileLister handler func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { - var err error - lister := r.getLister() + lister := r.getListerAt() if lister == nil { return statusFromError(pkt.id(), errors.New("unexpected dir packet")) } @@ -404,23 +499,25 @@ func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { switch r.Method { case "List": - if err != nil && err != io.EOF { + if err != nil && (err != io.EOF || n == 0) { return statusFromError(pkt.id(), err) } - if err == io.EOF && n == 0 { - return statusFromError(pkt.id(), io.EOF) - } - dirname := filepath.ToSlash(path.Base(r.Filepath)) - ret := &sshFxpNamePacket{ID: pkt.id()} + + nameAttrs := make([]*sshFxpNameAttr, 0, len(finfo)) for _, fi := range finfo { - ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{ + nameAttrs = append(nameAttrs, &sshFxpNameAttr{ Name: fi.Name(), - LongName: runLs(dirname, fi), + LongName: runLs(fi), Attrs: []interface{}{fi}, }) } - return ret + + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: nameAttrs, + } + default: err = fmt.Errorf("unexpected method: %s", r.Method) return statusFromError(pkt.id(), err) @@ -455,8 +552,11 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { return statusFromError(pkt.id(), err) } if n == 0 { - err = &os.PathError{Op: strings.ToLower(r.Method), Path: r.Filepath, - Err: syscall.ENOENT} + err = &os.PathError{ + Op: strings.ToLower(r.Method), + Path: r.Filepath, + Err: syscall.ENOENT, + } return statusFromError(pkt.id(), err) } return &sshFxpStatResponse{ @@ -468,8 +568,11 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { return statusFromError(pkt.id(), err) } if n == 0 { - err = &os.PathError{Op: "readlink", Path: r.Filepath, - Err: syscall.ENOENT} + err = &os.PathError{ + Op: "readlink", + Path: r.Filepath, + Err: syscall.ENOENT, + } return statusFromError(pkt.id(), err) } filename := finfo[0].Name() diff --git a/request_test.go b/request_test.go index d3f7db06..92f7c2bf 100644 --- a/request_test.go +++ b/request_test.go @@ -1,15 +1,13 @@ package sftp import ( - "sync" - - "github.com/stretchr/testify/assert" - "bytes" "errors" "io" "os" "testing" + + "github.com/stretchr/testify/assert" ) type testHandler struct { @@ -75,7 +73,6 @@ func testRequest(method string) *Request { Attrs: []byte("foo"), Flags: flags, Target: "foo", - state: state{RWMutex: new(sync.RWMutex)}, } return request } diff --git a/server.go b/server.go index e3599d88..a38d09bb 100644 --- a/server.go +++ b/server.go @@ -461,7 +461,6 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { return statusFromError(p.ID, EBADF) } - dirname := f.Name() dirents, err := f.Readdir(128) if err != nil { return statusFromError(p.ID, err) @@ -471,7 +470,7 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { for _, dirent := range dirents { ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{ Name: dirent.Name(), - LongName: runLs(dirname, dirent), + LongName: runLs(dirent), Attrs: []interface{}{dirent}, }) } diff --git a/server_stubs.go b/server_stubs.go index 62c9fa1a..84a7d7fc 100644 --- a/server_stubs.go +++ b/server_stubs.go @@ -8,7 +8,7 @@ import ( "time" ) -func runLs(dirname string, dirent os.FileInfo) string { +func runLs(dirent os.FileInfo) string { typeword := runLsTypeWord(dirent) numLinks := 1 if dirent.IsDir() { diff --git a/server_test.go b/server_test.go index 5a01c495..74da0c6b 100644 --- a/server_test.go +++ b/server_test.go @@ -25,14 +25,14 @@ const ( func TestRunLsWithExamplesDirectory(t *testing.T) { path := "examples" item, _ := os.Stat(path) - result := runLs(path, item) + result := runLs(item) runLsTestHelper(t, result, typeDirectory, path) } func TestRunLsWithLicensesFile(t *testing.T) { path := "LICENSE" item, _ := os.Stat(path) - result := runLs(path, item) + result := runLs(item) runLsTestHelper(t, result, typeFile, path) } @@ -79,61 +79,61 @@ func runLsTestHelper(t *testing.T, result, expectedType, path string) { // permissions (len 10, "drwxr-xr-x") got := result[0:10] if ok, err := regexp.MatchString("^"+expectedType+"[rwx-]{9}$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): permission field mismatch, expected dir, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): permission field mismatch, expected dir, got: %#v, err: %#v", got, err) } // space got = result[10:11] if ok, err := regexp.MatchString("^\\s$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): spacer 1 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): spacer 1 mismatch, expected whitespace, got: %#v, err: %#v", got, err) } // link count (len 3, number) got = result[12:15] if ok, err := regexp.MatchString("^\\s*[0-9]+$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): link count field mismatch, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): link count field mismatch, got: %#v, err: %#v", got, err) } // spacer got = result[15:16] if ok, err := regexp.MatchString("^\\s$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): spacer 2 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): spacer 2 mismatch, expected whitespace, got: %#v, err: %#v", got, err) } // username / uid (len 8, number or string) got = result[16:24] if ok, err := regexp.MatchString("^[^\\s]{1,8}\\s*$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): username / uid mismatch, expected user, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): username / uid mismatch, expected user, got: %#v, err: %#v", got, err) } // spacer got = result[24:25] if ok, err := regexp.MatchString("^\\s$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): spacer 3 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): spacer 3 mismatch, expected whitespace, got: %#v, err: %#v", got, err) } // groupname / gid (len 8, number or string) got = result[25:33] if ok, err := regexp.MatchString("^[^\\s]{1,8}\\s*$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): groupname / gid mismatch, expected group, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): groupname / gid mismatch, expected group, got: %#v, err: %#v", got, err) } // spacer got = result[33:34] if ok, err := regexp.MatchString("^\\s$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): spacer 4 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): spacer 4 mismatch, expected whitespace, got: %#v, err: %#v", got, err) } // filesize (len 8) got = result[34:42] if ok, err := regexp.MatchString("^\\s*[0-9]+$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): filesize field mismatch, expected size in bytes, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): filesize field mismatch, expected size in bytes, got: %#v, err: %#v", got, err) } // spacer got = result[42:43] if ok, err := regexp.MatchString("^\\s$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): spacer 5 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): spacer 5 mismatch, expected whitespace, got: %#v, err: %#v", got, err) } // mod time (len 12, e.g. Aug 9 19:46) @@ -146,19 +146,19 @@ func runLsTestHelper(t *testing.T, result, expectedType, path string) { _, err = time.Parse(layout, got) } if err != nil { - t.Errorf("runLs(%#v, *FileInfo): mod time field mismatch, expected date layout %s, got: %#v, err: %#v", path, layout, got, err) + t.Errorf("runLs(*FileInfo): mod time field mismatch, expected date layout %s, got: %#v, err: %#v", layout, got, err) } // spacer got = result[55:56] if ok, err := regexp.MatchString("^\\s$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): spacer 6 mismatch, expected whitespace, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): spacer 6 mismatch, expected whitespace, got: %#v, err: %#v", got, err) } // filename got = result[56:] if ok, err := regexp.MatchString("^"+path+"$", got); !ok { - t.Errorf("runLs(%#v, *FileInfo): name field mismatch, expected examples, got: %#v, err: %#v", path, got, err) + t.Errorf("runLs(*FileInfo): name field mismatch, expected examples, got: %#v, err: %#v", got, err) } } diff --git a/server_unix.go b/server_unix.go index a7b7617c..9ba54b60 100644 --- a/server_unix.go +++ b/server_unix.go @@ -12,7 +12,7 @@ import ( // ls -l style output for a file, which is in the 'long output' section of a readdir response packet // this is a very simple (lazy) implementation, just enough to look almost like openssh in a few basic cases -func runLs(dirname string, dirent os.FileInfo) string { +func runLs(dirent os.FileInfo) string { // example from openssh sftp server: // crw-rw-rw- 1 root wheel 0 Jul 31 20:52 ttyvd // format: