diff --git a/client.go b/client.go index aa17a392..a1aad7b3 100644 --- a/client.go +++ b/client.go @@ -1028,7 +1028,17 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { cancel := make(chan struct{}) + concurrency := len(b)/f.c.maxPacket + 1 + if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { + concurrency = f.c.maxConcurrentRequests + } + + resPool := newResChanPool(concurrency) + type work struct { + id uint32 + res chan result + b []byte off int64 } @@ -1048,8 +1058,18 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { rb = rb[:chunkSize] } + id := f.c.nextID() + res := resPool.Get() + + f.c.dispatchRequest(res, &sshFxpReadPacket{ + ID: id, + Handle: f.handle, + Offset: uint64(offset), + Len: uint32(chunkSize), + }) + select { - case workCh <- work{rb, offset}: + case workCh <- work{id, res, rb, offset}: case <-cancel: return } @@ -1065,11 +1085,6 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { } errCh := make(chan rErr) - concurrency := len(b)/f.c.maxPacket + 1 - if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { - concurrency = f.c.maxConcurrentRequests - } - var wg sync.WaitGroup wg.Add(concurrency) for i := 0; i < concurrency; i++ { @@ -1077,10 +1092,40 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { go func() { defer wg.Done() - ch := make(chan result, 1) // reusable channel per mapper. - for packet := range workCh { - n, err := f.readChunkAt(ch, packet.b, packet.off) + var n int + + s := <-packet.res + resPool.Put(packet.res) + + err := s.err + if err == nil { + switch s.typ { + case sshFxpStatus: + err = normaliseError(unmarshalStatus(packet.id, s.data)) + + case sshFxpData: + sid, data := unmarshalUint32(s.data) + if packet.id != sid { + err = &unexpectedIDErr{packet.id, sid} + + } else { + l, data := unmarshalUint32(data) + n = copy(packet.b, data[:l]) + + // For normal disk files, it is guaranteed that this will read + // the specified number of bytes, or up to end of file. + // This implies, if we have a short read, that means EOF. + if n < len(packet.b) { + err = io.EOF + } + } + + default: + err = unimplementedPacketErr(s.typ) + } + } + if err != nil { // return the offset as the start + how much we read before the error. errCh <- rErr{packet.off + int64(n), err}