Skip to content

Commit

Permalink
Merge pull request #373 from drakkan/fsetstat
Browse files Browse the repository at this point in the history
request server: add support for SSH_FXP_FSETSTAT
  • Loading branch information
drakkan authored Aug 25, 2020
2 parents 2c44234 + 07229f2 commit 06ab92e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ go_import_path: github.com/pkg/sftp
# current and previous stable releases, plus tip
# remember to exclude previous and tip for macs below
go:
- 1.13.x
- 1.14.x
- 1.15.x
- tip

os:
Expand All @@ -15,7 +15,7 @@ os:
matrix:
exclude:
- os: osx
go: 1.13.x
go: 1.14.x
- os: osx
go: tip

Expand Down
32 changes: 26 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,25 @@ func (c *Client) Symlink(oldname, newname string) error {
}
}

func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(sshFxpFsetstatPacket{
ID: id,
Handle: handle,
Flags: flags,
Attrs: attrs,
})
if err != nil {
return err
}
switch typ {
case sshFxpStatus:
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
}

// setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
id := c.nextID()
Expand Down Expand Up @@ -817,7 +836,7 @@ type File struct {
path string
handle string

mu sync.Mutex
mu sync.Mutex
offset uint64 // current offset within remote file
}

Expand Down Expand Up @@ -845,13 +864,13 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

r, err := f.ReadAt(b, int64( f.offset ))
r, err := f.ReadAt(b, int64(f.offset))
f.offset += uint64(r)
return r, err
}

// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (n int, err error) {
// Split the read into multiple maxPacket sized concurrent reads
Expand All @@ -860,7 +879,7 @@ func (f *File) ReadAt(b []byte, off int64) (n int, err error) {
// overlapping round trip times.
inFlight := 0
desiredInFlight := 1
offset := uint64( off )
offset := uint64(off)
// maxConcurrentRequests buffer to deal with broadcastErr() floods
// also must have a buffer of max value of (desiredInFlight - inFlight)
ch := make(chan result, f.c.maxConcurrentRequests+1)
Expand Down Expand Up @@ -1280,8 +1299,9 @@ func (f *File) Chmod(mode os.FileMode) error {
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.Truncate(f.path, size)
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
}

func min(a, b int) int {
Expand Down
19 changes: 19 additions & 0 deletions request-example.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ func (fs *root) Filecmd(r *Request) error {
defer fs.filesLock.Unlock()
switch r.Method {
case "Setstat":
file, err := fs.fetch(r.Filepath)
if err != nil {
return err
}
if r.AttrFlags().Size {
return file.Truncate(int64(r.Attributes().Size))
}
return nil
case "Rename":
file, err := fs.fetch(r.Filepath)
Expand Down Expand Up @@ -302,6 +309,18 @@ func (f *memFile) WriteAt(p []byte, off int64) (int, error) {
return len(p), nil
}

func (f *memFile) Truncate(size int64) error {
f.contentLock.Lock()
defer f.contentLock.Unlock()
grow := size - int64(len(f.content))
if grow <= 0 {
f.content = f.content[:size]
} else {
f.content = append(f.content, make([]byte, grow)...)
}
return nil
}

func (f *memFile) TransferError(err error) {
f.transferError = err
}
9 changes: 9 additions & 0 deletions request-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,15 @@ func (rs *RequestServer) packetWorker(
request = NewRequest("Stat", request.Filepath)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = NewRequest("Setstat", request.Filepath)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
}
case *sshFxpExtendedPacketPosixRename:
request := NewRequest("Rename", pkt.Oldpath)
request.Target = pkt.Newpath
Expand Down
32 changes: 32 additions & 0 deletions request-server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var _ = fmt.Print
Expand Down Expand Up @@ -345,6 +346,37 @@ func TestRequestFstat(t *testing.T) {
checkRequestServerAllocator(t, p)
}

func TestRequestFsetstat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
require.NoError(t, err)
fp, err := p.cli.OpenFile("/foo", os.O_WRONLY)
require.NoError(t, err)
err = fp.Truncate(2)
fi, err := fp.Stat()
require.NoError(t, err)
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(2))
err = fp.Truncate(5)
require.NoError(t, err)
fi, err = fp.Stat()
require.NoError(t, err)
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(5))
err = fp.Close()
assert.NoError(t, err)
rf, err := p.cli.Open("/foo")
assert.NoError(t, err)
defer rf.Close()
contents := make([]byte, 20)
n, err := rf.Read(contents)
assert.EqualError(t, err, io.EOF.Error())
assert.Equal(t, 5, n)
assert.Equal(t, []byte{'h', 'e', 0, 0, 0}, contents[0:n])
checkRequestServerAllocator(t, p)
}

func TestRequestStatFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
Expand Down

0 comments on commit 06ab92e

Please sign in to comment.