diff --git a/packet.go b/packet.go index 1a1a87d7..7fd605c6 100644 --- a/packet.go +++ b/packet.go @@ -1247,7 +1247,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { } func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { - err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) + err := os.Rename(toLocalPath(s.workDir, p.Oldpath), toLocalPath(s.workDir, p.Newpath)) return statusFromError(p.ID, err) } @@ -1276,6 +1276,6 @@ func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error { } func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket { - err := os.Link(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) + err := os.Link(toLocalPath(s.workDir, p.Oldpath), toLocalPath(s.workDir, p.Newpath)) return statusFromError(p.ID, err) } diff --git a/request-plan9.go b/request-plan9.go index 2444da59..daffd3e1 100644 --- a/request-plan9.go +++ b/request-plan9.go @@ -16,7 +16,15 @@ func testOsSys(sys interface{}) error { return nil } -func toLocalPath(p string) string { +func toLocalPath(workDir, p string) string { + if workDir != "" { + if !filepath.IsAbs(p) && !path.IsAbs(p) { + // Ensure input is always in the same format. + p = filepath.ToSlash(p) + p = path.Join(workDir, p) + } + } + lp := filepath.FromSlash(p) if path.IsAbs(p) { @@ -28,6 +36,7 @@ func toLocalPath(p string) string { // e.g. "/#s/boot" to "#s/boot" return tmp } + } return lp diff --git a/request-unix.go b/request-unix.go index 50b08a38..2273d567 100644 --- a/request-unix.go +++ b/request-unix.go @@ -4,6 +4,8 @@ package sftp import ( "errors" + "path" + "path/filepath" "syscall" ) @@ -22,6 +24,14 @@ func testOsSys(sys interface{}) error { return nil } -func toLocalPath(p string) string { +func toLocalPath(workDir, p string) string { + if workDir != "" { + if !filepath.IsAbs(p) && !path.IsAbs(p) { + // Ensure input is always in the same format. + p = filepath.ToSlash(p) + p = path.Join(workDir, p) + } + } + return p } diff --git a/request_windows.go b/request_windows.go index 1f6d3df1..d53cecf2 100644 --- a/request_windows.go +++ b/request_windows.go @@ -14,7 +14,15 @@ func testOsSys(sys interface{}) error { return nil } -func toLocalPath(p string) string { +func toLocalPath(workDir, p string) string { + if workDir != "" { + if !filepath.IsAbs(p) && !path.IsAbs(p) { + // Ensure input is always in the same format. + p = filepath.ToSlash(p) + p = path.Join(workDir, p) + } + } + lp := filepath.FromSlash(p) if path.IsAbs(p) { @@ -38,6 +46,7 @@ func toLocalPath(p string) string { // e.g. "/C:" to "C:\\" return tmp } + } return lp diff --git a/server.go b/server.go index 529052b4..677a10c7 100644 --- a/server.go +++ b/server.go @@ -33,6 +33,7 @@ type Server struct { openFiles map[string]*os.File openFilesLock sync.RWMutex handleCount int + workDir string } func (svr *Server) nextHandle(f *os.File) string { @@ -128,6 +129,16 @@ func WithAllocator() ServerOption { } } +// WithServerWorkingDirectory sets a working directory to use as base +// for relative paths. +// If unset the default is current working directory (os.Getwd). +func WithServerWorkingDirectory(workDir string) ServerOption { + return func(s *Server) error { + s.workDir = cleanPath(workDir) + return nil + } +} + type rxPacket struct { pktType fxp pktBytes []byte @@ -174,7 +185,7 @@ func handlePacket(s *Server, p orderedRequest) error { } case *sshFxpStatPacket: // stat the requested file - info, err := os.Stat(toLocalPath(p.Path)) + info, err := os.Stat(toLocalPath(s.workDir, p.Path)) rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, @@ -184,7 +195,7 @@ func handlePacket(s *Server, p orderedRequest) error { } case *sshFxpLstatPacket: // stat the requested file - info, err := os.Lstat(toLocalPath(p.Path)) + info, err := os.Lstat(toLocalPath(s.workDir, p.Path)) rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, @@ -208,24 +219,24 @@ func handlePacket(s *Server, p orderedRequest) error { } case *sshFxpMkdirPacket: // TODO FIXME: ignore flags field - err := os.Mkdir(toLocalPath(p.Path), 0755) + err := os.Mkdir(toLocalPath(s.workDir, p.Path), 0o755) rpkt = statusFromError(p.ID, err) case *sshFxpRmdirPacket: - err := os.Remove(toLocalPath(p.Path)) + err := os.Remove(toLocalPath(s.workDir, p.Path)) rpkt = statusFromError(p.ID, err) case *sshFxpRemovePacket: - err := os.Remove(toLocalPath(p.Filename)) + err := os.Remove(toLocalPath(s.workDir, p.Filename)) rpkt = statusFromError(p.ID, err) case *sshFxpRenamePacket: - err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) + err := os.Rename(toLocalPath(s.workDir, p.Oldpath), toLocalPath(s.workDir, p.Newpath)) rpkt = statusFromError(p.ID, err) case *sshFxpSymlinkPacket: - err := os.Symlink(toLocalPath(p.Targetpath), toLocalPath(p.Linkpath)) + err := os.Symlink(toLocalPath(s.workDir, p.Targetpath), toLocalPath(s.workDir, p.Linkpath)) rpkt = statusFromError(p.ID, err) case *sshFxpClosePacket: rpkt = statusFromError(p.ID, s.closeHandle(p.Handle)) case *sshFxpReadlinkPacket: - f, err := os.Readlink(toLocalPath(p.Path)) + f, err := os.Readlink(toLocalPath(s.workDir, p.Path)) rpkt = &sshFxpNamePacket{ ID: p.ID, NameAttrs: []*sshFxpNameAttr{ @@ -240,29 +251,21 @@ func handlePacket(s *Server, p orderedRequest) error { rpkt = statusFromError(p.ID, err) } case *sshFxpRealpathPacket: - f, err := filepath.Abs(toLocalPath(p.Path)) + f, err := filepath.Abs(toLocalPath(s.workDir, p.Path)) f = cleanPath(f) - rpkt = &sshFxpNamePacket{ - ID: p.ID, - NameAttrs: []*sshFxpNameAttr{ - { - Name: f, - LongName: f, - Attrs: emptyFileStat, - }, - }, - } + rpkt = cleanPacketPath(p, f) if err != nil { rpkt = statusFromError(p.ID, err) } case *sshFxpOpendirPacket: - p.Path = toLocalPath(p.Path) + p.Path = toLocalPath(s.workDir, p.Path) if stat, err := os.Stat(p.Path); err != nil { rpkt = statusFromError(p.ID, err) } else if !stat.IsDir() { rpkt = statusFromError(p.ID, &os.PathError{ - Path: p.Path, Err: syscall.ENOTDIR}) + Path: p.Path, Err: syscall.ENOTDIR, + }) } else { rpkt = (&sshFxpOpenPacket{ ID: p.ID, @@ -446,7 +449,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { osFlags |= os.O_EXCL } - f, err := os.OpenFile(toLocalPath(p.Path), osFlags, 0644) + f, err := os.OpenFile(toLocalPath(svr.workDir, p.Path), osFlags, 0o644) if err != nil { return statusFromError(p.ID, err) } @@ -484,7 +487,7 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { b := p.Attrs.([]byte) var err error - p.Path = toLocalPath(p.Path) + p.Path = toLocalPath(svr.workDir, p.Path) debug("setstat name \"%s\"", p.Path) if (p.Flags & sshFileXferAttrSize) != 0 {