Skip to content

Commit

Permalink
request server: add support for SSH_FXP_FSETSTAT
Browse files Browse the repository at this point in the history
we need to add a case for this packet inside the packet worker otherwise
it will be handled in hasHandle case and it will become a "Put" request.

Client side if a Truncate request is called on the open file we should
send a FSETSTAT packet, the request is on the handle, and not a SETSTAT
packet that should be used for paths and not for handle.
  • Loading branch information
drakkan committed Aug 22, 2020
1 parent a6e55f6 commit b485931
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 6 deletions.
40 changes: 34 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,25 @@ func (c *Client) Symlink(oldname, newname string) error {
}
}

func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(sshFxpFsetstatPacket{
ID: id,
Handle: handle,
Flags: flags,
Attrs: attrs,
})
if err != nil {
return err
}
switch typ {
case sshFxpStatus:
return normaliseError(unmarshalStatus(id, data))
default:
return unimplementedPacketErr(typ)
}
}

// setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
id := c.nextID()
Expand Down Expand Up @@ -817,7 +836,7 @@ type File struct {
path string
handle string

mu sync.Mutex
mu sync.Mutex
offset uint64 // current offset within remote file
}

Expand Down Expand Up @@ -845,13 +864,13 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()

r, err := f.ReadAt(b, int64( f.offset ))
r, err := f.ReadAt(b, int64(f.offset))
f.offset += uint64(r)
return r, err
}

// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (n int, err error) {
// Split the read into multiple maxPacket sized concurrent reads
Expand All @@ -860,7 +879,7 @@ func (f *File) ReadAt(b []byte, off int64) (n int, err error) {
// overlapping round trip times.
inFlight := 0
desiredInFlight := 1
offset := uint64( off )
offset := uint64(off)
// maxConcurrentRequests buffer to deal with broadcastErr() floods
// also must have a buffer of max value of (desiredInFlight - inFlight)
ch := make(chan result, f.c.maxConcurrentRequests+1)
Expand Down Expand Up @@ -1280,8 +1299,17 @@ func (f *File) Chmod(mode os.FileMode) error {
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.Truncate(f.path, size)
err := f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
if err == nil {
// reset the offset for future writes
f.mu.Lock()
defer f.mu.Unlock()

f.offset = uint64(size)
}
return err
}

func min(a, b int) int {
Expand Down
9 changes: 9 additions & 0 deletions request-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,15 @@ func (rs *RequestServer) packetWorker(
request = NewRequest("Stat", request.Filepath)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = NewRequest("Setstat", request.Filepath)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
}
case *sshFxpExtendedPacketPosixRename:
request := NewRequest("Rename", pkt.Oldpath)
request.Target = pkt.Newpath
Expand Down
34 changes: 34 additions & 0 deletions request-server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,40 @@ func TestRequestFstat(t *testing.T) {
checkRequestServerAllocator(t, p)
}

func TestRequestFsetstat(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
assert.Nil(t, err)
fp, err := p.cli.OpenFile("/foo", os.O_WRONLY)
assert.Nil(t, err)
err = fp.Truncate(2)
if assert.NoError(t, err) {
fi, err := fp.Stat()
if assert.NoError(t, err) {
assert.Equal(t, fi.Name(), "foo")
assert.Equal(t, fi.Size(), int64(2))
}
}
// we expect the truncate size (2) as offset for this write
n, err := fp.Write([]byte("hello"))
assert.NoError(t, err)
assert.Equal(t, 5, n)
err = fp.Close()
assert.NoError(t, err)
rf, err := p.cli.Open("/foo")
assert.Nil(t, err)
defer rf.Close()
contents := make([]byte, 20)
n, err = rf.Read(contents)
if err != nil && err != io.EOF {
t.Fatalf("err: %v", err)
}
assert.Equal(t, 2+5, n)
assert.Equal(t, "hehello", string(contents[0:n]))
checkRequestServerAllocator(t, p)
}

func TestRequestStatFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
Expand Down
19 changes: 19 additions & 0 deletions request_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ func (fs *root) Filecmd(r *Request) error {
defer fs.filesLock.Unlock()
switch r.Method {
case "Setstat":
file, err := fs.fetch(r.Filepath)
if err != nil {
return err
}
if r.AttrFlags().Size {
return file.Truncate(int64(r.Attributes().Size))
}
return nil
case "Rename":
file, err := fs.fetch(r.Filepath)
Expand Down Expand Up @@ -302,6 +309,18 @@ func (f *memFile) WriteAt(p []byte, off int64) (int, error) {
return len(p), nil
}

func (f *memFile) Truncate(size int64) error {
f.contentLock.Lock()
defer f.contentLock.Unlock()
grow := size - int64(len(f.content))
if grow <= 0 {
f.content = f.content[:size]
} else {
f.content = append(f.content, make([]byte, grow)...)
}
return nil
}

func (f *memFile) TransferError(err error) {
f.transferError = err
}

0 comments on commit b485931

Please sign in to comment.