From 6079724764d101937b49b596b1cf523c537e635d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9B=B9=E6=98=8E=E9=B8=A3?= Date: Wed, 14 Sep 2022 16:11:32 +0800 Subject: [PATCH] fix `ReadFrom` generate corrupted bitset when reader incompletelyb fills buf --- bitset.go | 2 +- bitset_test.go | 36 ++++++++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/bitset.go b/bitset.go index 3829a2b..32d7cfb 100644 --- a/bitset.go +++ b/bitset.go @@ -945,7 +945,7 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) { nWords := wordsNeeded(uint(length)) reader := bufio.NewReader(io.LimitReader(stream, 8*int64(nWords))) for i := 0; i < nWords; i++ { - if _, err := reader.Read(item[:]); err != nil { + if _, err := io.ReadFull(reader, item[:]); err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } diff --git a/bitset_test.go b/bitset_test.go index 8d86be1..4bc11d7 100644 --- a/bitset_test.go +++ b/bitset_test.go @@ -1713,18 +1713,38 @@ func TestWriteTo(t *testing.T) { } } +type inCompleteRetBufReader struct { + returnEvery int64 + reader io.Reader + offset int64 +} + +func (ir *inCompleteRetBufReader) Read(b []byte) (n int, err error) { + if ir.returnEvery > 0 { + maxRead := ir.returnEvery - (ir.offset % ir.returnEvery) + if len(b) > int(maxRead) { + b = b[:maxRead] + } + } + n, err = ir.reader.Read(b) + ir.offset += int64(n) + return +} + func TestReadFrom(t *testing.T) { addBuf := []byte(`12345678`) // Bytes after stream tests := []struct { - length uint - oneEvery uint - input string // base64+gzipped - wantErr error + length uint + oneEvery uint + input string // base64+gzipped + wantErr error + returnEvery int64 }{ { - length: 9585, - oneEvery: 97, - input: "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwMSACVgYGBg4sIgLMDAwKGARd2BgYGjAFB41noDx6IAJajw64IAajw4UoMajg4ZR4/EaP5pQh1g+MDQyNjE1M7cABAAA//9W5OoOwAQAAA==", + length: 9585, + oneEvery: 97, + input: "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwMSACVgYGBg4sIgLMDAwKGARd2BgYGjAFB41noDx6IAJajw64IAajw4UoMajg4ZR4/EaP5pQh1g+MDQyNjE1M7cABAAA//9W5OoOwAQAAA==", + returnEvery: 127, }, { length: 1337, @@ -1764,7 +1784,7 @@ func TestReadFrom(t *testing.T) { fatalErr(gz.Close()) bs := New(test.length) - _, err = bs.ReadFrom(&buf) + _, err = bs.ReadFrom(&inCompleteRetBufReader{returnEvery: test.returnEvery, reader: &buf}) if err != nil { if errors.Is(err, test.wantErr) { // Correct, nothing more we can test.