Skip to content

Commit

Permalink
rework client to prevent after-close usage, and support perm at open
Browse files Browse the repository at this point in the history
puellanivis committed Jan 19, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 22452ea commit d1903fb
Showing 9 changed files with 381 additions and 110 deletions.
19 changes: 17 additions & 2 deletions attrs.go
Original file line number Diff line number Diff line change
@@ -32,10 +32,10 @@ func (fi *fileInfo) Name() string { return fi.name }
func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) }

// Mode returns file mode bits.
func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) }
func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() }

// ModTime returns the last modification time of the file.
func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) }
func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() }

// IsDir returns true if the file is a directory.
func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() }
@@ -56,6 +56,21 @@ type FileStat struct {
Extended []StatExtended
}

// ModTime returns the Mtime SFTP file attribute converted to a time.Time
func (fs *FileStat) ModTime() time.Time {
return time.Unix(int64(fs.Mtime), 0)
}

// AccessTime returns the Atime SFTP file attribute converted to a time.Time
func (fs *FileStat) AccessTime() time.Time {
return time.Unix(int64(fs.Atime), 0)
}

// FileMode returns the Mode SFTP file attribute converted to an os.FileMode
func (fs *FileStat) FileMode() os.FileMode {
return toFileMode(fs.Mode)
}

// StatExtended contains additional, extended information for a FileStat.
type StatExtended struct {
ExtType string
156 changes: 131 additions & 25 deletions client.go
Original file line number Diff line number Diff line change
@@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
// read/write at the same time. For those services you will need to use
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
func (c *Client) Create(path string) (*File, error) {
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
}

const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
@@ -510,7 +510,7 @@ func (c *Client) Symlink(oldname, newname string) error {
}
}

func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id,
@@ -590,14 +590,14 @@ func (c *Client) Truncate(path string, size int64) error {
// returned file can be used for reading; the associated file descriptor
// has mode O_RDONLY.
func (c *Client) Open(path string) (*File, error) {
return c.open(path, flags(os.O_RDONLY))
return c.open(path, toPflags(os.O_RDONLY))
}

// OpenFile is the generalized open call; most users will use Open or
// Create instead. It opens the named file with specified flag (O_RDONLY
// etc.). If successful, methods on the returned File can be used for I/O.
func (c *Client) OpenFile(path string, f int) (*File, error) {
return c.open(path, flags(f))
return c.open(path, toPflags(f))
}

func (c *Client) open(path string, pflags uint32) (*File, error) {
@@ -976,16 +976,26 @@ func (c *Client) RemoveAll(path string) error {
type File struct {
c *Client
path string
handle string

mu sync.Mutex
mu sync.RWMutex
handle string
offset int64 // current offset within remote file
}

// Close closes the File, rendering it unusable for I/O. It returns an
// error, if any.
func (f *File) Close() error {
return f.c.close(f.handle)
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return os.ErrClosed
}

handle := f.handle
f.handle = ""

return f.c.close(handle)
}

// Name returns the name of the file as presented to Open or Create.
@@ -1006,7 +1016,11 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

n, err := f.ReadAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}

n, err := f.readAt(b, f.offset)
f.offset += int64(n)
return n, err
}
@@ -1071,6 +1085,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (int, error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.readAt(b, off)
}

func (f *File) readAt(b []byte, off int64) (int, error) {
if len(b) <= f.c.maxPacket {
// This should be able to be serviced with 1/2 requests.
// So, just do it directly.
@@ -1267,6 +1292,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

if f.c.disableConcurrentReads {
return f.writeToSequential(w)
}
@@ -1456,9 +1485,20 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
}
}

func (f *File) Stat() (os.FileInfo, error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return nil, os.ErrClosed
}

return f.stat()
}

// Stat returns the FileInfo structure describing file. If there is an
// error.
func (f *File) Stat() (os.FileInfo, error) {
func (f *File) stat() (os.FileInfo, error) {
fs, err := f.c.fstat(f.handle)
if err != nil {
return nil, err
@@ -1478,7 +1518,11 @@ func (f *File) Write(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

n, err := f.WriteAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}

n, err := f.writeAt(b, f.offset)
f.offset += int64(n)
return n, err
}
@@ -1636,6 +1680,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
// so the file offset is not altered during the write.
func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.writeAt(b, off)
}

func (f *File) writeAt(b []byte, off int64) (written int, err error) {
if len(b) <= f.c.maxPacket {
// We can do this in one write.
return f.writeChunkAt(nil, b, off)
@@ -1675,6 +1730,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
//
// Otherwise, the given concurrency will be capped by the Client's max concurrency.
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

return f.readFromWithConcurrency(r, concurrency)
}

func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
// Split the write into multiple maxPacket sized concurrent writes.
// This allows writes with a suitably large reader
// to transfer data at a much faster rate due to overlapping round trip times.
@@ -1824,6 +1890,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

if f.c.useConcurrentWrites {
var remain int64
switch r := r.(type) {
@@ -1845,7 +1915,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {

if remain < 0 {
// We can strongly assert that we want default max concurrency here.
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests)
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
}

if remain > int64(f.c.maxPacket) {
@@ -1860,7 +1930,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
concurrency64 = int64(f.c.maxConcurrentRequests)
}

return f.ReadFromWithConcurrency(r, int(concurrency64))
return f.readFromWithConcurrency(r, int(concurrency64))
}
}

@@ -1903,12 +1973,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return 0, os.ErrClosed
}

switch whence {
case io.SeekStart:
case io.SeekCurrent:
offset += f.offset
case io.SeekEnd:
fi, err := f.Stat()
fi, err := f.stat()
if err != nil {
return f.offset, err
}
@@ -1927,20 +2001,61 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {

// Chown changes the uid/gid of the current file.
func (f *File) Chown(uid, gid int) error {
return f.c.Chown(f.path, uid, gid)
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{
UID: uint32(uid),
GID: uint32(gid),
})
}

// Chmod changes the permissions of the current file.
//
// See Client.Chmod for details.
func (f *File) Chmod(mode os.FileMode) error {
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
}

// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
f.mu.RLock()
defer f.mu.RUnlock()

if f.handle == "" {
return os.ErrClosed
}

return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size))
}

// Sync requests a flush of the contents of a File to stable storage.
//
// Sync requires the server to support the fsync@openssh.com extension.
func (f *File) Sync() error {
f.mu.Lock()
defer f.mu.Unlock()

if f.handle == "" {
return os.ErrClosed
}


id := f.c.nextID()
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
@@ -1957,15 +2072,6 @@ func (f *File) Sync() error {
}
}

// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
}

// normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist.
func normaliseError(err error) error {
@@ -1990,7 +2096,7 @@ func normaliseError(err error) error {

// flags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored.
func flags(f int) uint32 {
func toPflags(f int) uint32 {
var out uint32
switch f & os.O_WRONLY {
case os.O_WRONLY:
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
@@ -81,7 +81,7 @@ var flagsTests = []struct {

func TestFlags(t *testing.T) {
for i, tt := range flagsTests {
got := flags(tt.flags)
got := toPflags(tt.flags)
if got != tt.want {
t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got)
}
Loading

0 comments on commit d1903fb

Please sign in to comment.