Skip to content

Commit

Permalink
Merge pull request #443 from pkg/patch/sequential-concurrent-readat-requests
Browse files Browse the repository at this point in the history

[bugfix] Sequentially issue read requests in ReadAt the same as WriteTo
  • Loading branch information
puellanivis authored Jun 30, 2021
2 parents 5b98d05 + 6617a3a commit d9a1139
Showing 1 changed file with 54 additions and 9 deletions.
63 changes: 54 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,17 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {

cancel := make(chan struct{})

concurrency := len(b)/f.c.maxPacket + 1
if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
concurrency = f.c.maxConcurrentRequests
}

resPool := newResChanPool(concurrency)

type work struct {
id uint32
res chan result

b []byte
off int64
}
Expand All @@ -1048,8 +1058,18 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
rb = rb[:chunkSize]
}

id := f.c.nextID()
res := resPool.Get()

f.c.dispatchRequest(res, &sshFxpReadPacket{
ID: id,
Handle: f.handle,
Offset: uint64(offset),
Len: uint32(chunkSize),
})

select {
case workCh <- work{rb, offset}:
case workCh <- work{id, res, rb, offset}:
case <-cancel:
return
}
Expand All @@ -1065,22 +1085,47 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
}
errCh := make(chan rErr)

concurrency := len(b)/f.c.maxPacket + 1
if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
concurrency = f.c.maxConcurrentRequests
}

var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
// Map_i: each worker gets work, and then performs the Read into its buffer from its respective offset.
go func() {
defer wg.Done()

ch := make(chan result, 1) // reusable channel per mapper.

for packet := range workCh {
n, err := f.readChunkAt(ch, packet.b, packet.off)
var n int

s := <-packet.res
resPool.Put(packet.res)

err := s.err
if err == nil {
switch s.typ {
case sshFxpStatus:
err = normaliseError(unmarshalStatus(packet.id, s.data))

case sshFxpData:
sid, data := unmarshalUint32(s.data)
if packet.id != sid {
err = &unexpectedIDErr{packet.id, sid}

} else {
l, data := unmarshalUint32(data)
n = copy(packet.b, data[:l])

// For normal disk files, it is guaranteed that this will read
// the specified number of bytes, or up to end of file.
// This implies, if we have a short read, that means EOF.
if n < len(packet.b) {
err = io.EOF
}
}

default:
err = unimplementedPacketErr(s.typ)
}
}

if err != nil {
// return the offset as the start + how much we read before the error.
errCh <- rErr{packet.off + int64(n), err}
Expand Down

0 comments on commit d9a1139

Please sign in to comment.