Skip to content

Commit

Permalink
bitarray: don't allow FromEncodingParts to return invalid bit array
Browse files Browse the repository at this point in the history
It is invalid for a bit array's lastBitsUsed field to be greater than
64. The FromEncodingParts function, however, would happily construct
an invalid bitarray if given a too-large lastBitsUsed value. Teach the
FromEncodingParts to return an error instead.

This presented as a panic when attempting to pretty-print a key with a
bitarray whose lastBitsUsed encoded value was 65. Such a key can be
created when calling PrefixEnd on a key with a bitarray whose
lastBitsUsed value is 64. By returning an error instead, the
pretty-printing code will try again after calling UndoPrefixEnd and be
able to print the key.

Fix #31115.

Release note: None
  • Loading branch information
benesch committed Oct 22, 2018
1 parent 2998190 commit eaf5808
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
5 changes: 5 additions & 0 deletions pkg/keys/printer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ func TestPrettyPrint(t *testing.T) {
{makeKey(MakeTablePrefix(42),
roachpb.RKey(encoding.EncodeBitArrayDescending(nil, bitArray))),
"/Table/42/B00111010"},
// Regression test for #31115.
{roachpb.Key(makeKey(MakeTablePrefix(42),
roachpb.RKey(encoding.EncodeBitArrayAscending(nil, bitarray.MakeZeroBitArray(64))),
)).PrefixEnd(),
"/Table/42/B0000000000000000000000000000000000000000000000000000000000000000/PrefixEnd"},
{makeKey(MakeTablePrefix(42),
roachpb.RKey(durationAsc)),
"/Table/42/1mon1d1s"},
Expand Down
23 changes: 18 additions & 5 deletions pkg/util/bitarray/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (d BitArray) Clone() BitArray {
// MakeZeroBitArray creates a bit array with the specified bit size.
func MakeZeroBitArray(bitLen uint) BitArray {
a, b := EncodingPartsForBitLen(bitLen)
return FromEncodingParts(a, b)
return mustFromEncodingParts(a, b)
}

// ToWidth resizes the bit array to the specified size.
Expand All @@ -128,7 +128,7 @@ func (d BitArray) ToWidth(desiredLen uint) BitArray {
words, lastBitsUsed := EncodingPartsForBitLen(desiredLen)
copy(words, d.words[:len(words)])
words[len(words)-1] &= (^word(0) << (numBitsPerWord - lastBitsUsed))
return FromEncodingParts(words, lastBitsUsed)
return mustFromEncodingParts(words, lastBitsUsed)
}

// New length is larger.
Expand All @@ -140,7 +140,7 @@ func (d BitArray) ToWidth(desiredLen uint) BitArray {
words = make([]word, numWords)
copy(words, d.words)
}
return FromEncodingParts(words, lastBitsUsed)
return mustFromEncodingParts(words, lastBitsUsed)
}

// Sizeof returns the size in bytes of the bit array and its components.
Expand Down Expand Up @@ -346,7 +346,7 @@ func Parse(s string) (res BitArray, err error) {
words[wordIdx] = curWord
}

return FromEncodingParts(words, lastBitsUsed), nil
return FromEncodingParts(words, lastBitsUsed)
}

// Concat concatenates two bit arrays.
Expand Down Expand Up @@ -481,11 +481,24 @@ func (d BitArray) EncodingParts() ([]uint64, uint64) {
}

// FromEncodingParts creates a bit array from the encoding parts.
func FromEncodingParts(words []uint64, lastBitsUsed uint64) BitArray {
func FromEncodingParts(words []uint64, lastBitsUsed uint64) (BitArray, error) {
if lastBitsUsed > numBitsPerWord {
return BitArray{}, fmt.Errorf("FromEncodingParts: lastBitsUsed must not exceed %d, got %d",
numBitsPerWord, lastBitsUsed)
}
return BitArray{
words: words,
lastBitsUsed: uint8(lastBitsUsed),
}, nil
}

// mustFromEncodingParts is like FromEncodingParts but errors cause a panic.
func mustFromEncodingParts(words []uint64, lastBitsUsed uint64) BitArray {
ba, err := FromEncodingParts(words, lastBitsUsed)
if err != nil {
panic(err)
}
return ba
}

// Rand generates a random bit array of the specified length.
Expand Down
27 changes: 27 additions & 0 deletions pkg/util/bitarray/bitarray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,33 @@ func TestParseFormat(t *testing.T) {
}
}

func TestFromEncodingParts(t *testing.T) {
testData := []struct {
words []uint64
lastBitsUsed uint64
ba BitArray
err string
}{
{nil, 0, BitArray{words: nil, lastBitsUsed: 0}, ""},
{[]uint64{0}, 0, BitArray{words: []word{0}, lastBitsUsed: 0}, ""},
{[]uint64{42}, 3, BitArray{words: []word{42}, lastBitsUsed: 3}, ""},
{[]uint64{42}, 65, BitArray{}, "FromEncodingParts: lastBitsUsed must not exceed 64, got 65"},
}

for _, test := range testData {
t.Run(fmt.Sprintf("{%v,%d}", test.words, test.lastBitsUsed), func(t *testing.T) {
ba, err := FromEncodingParts(test.words, test.lastBitsUsed)
if test.err != "" && (err == nil || test.err != err.Error()) {
t.Errorf("expected %q error, but got: %+v", test.err, err)
} else if test.err == "" && err != nil {
t.Errorf("unexpected error: %s", err)
} else if !reflect.DeepEqual(ba, test.ba) {
t.Errorf("expected %+v, got %+v", test.ba, ba)
}
})
}
}

func TestToWidth(t *testing.T) {
testData := []struct {
str string
Expand Down
15 changes: 12 additions & 3 deletions pkg/util/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,11 @@ func DecodeBitArrayAscending(b []byte) ([]byte, bitarray.BitArray, error) {
}
b = b[1:]
b, lastVal, err := DecodeUvarintAscending(b)
return b, bitarray.FromEncodingParts(words, lastVal), err
if err != nil {
return b, bitarray.BitArray{}, err
}
ba, err := bitarray.FromEncodingParts(words, lastVal)
return b, ba, err
}

var errBitArrayTerminatorMissing = errors.New("cannot find bit array data terminator")
Expand Down Expand Up @@ -1165,7 +1169,11 @@ func DecodeBitArrayDescending(b []byte) ([]byte, bitarray.BitArray, error) {
}
b = b[1:]
b, lastVal, err := DecodeUvarintDescending(b)
return b, bitarray.FromEncodingParts(words, lastVal), err
if err != nil {
return b, bitarray.BitArray{}, err
}
ba, err := bitarray.FromEncodingParts(words, lastVal)
return b, ba, err
}

// Type represents the type of a value encoded by
Expand Down Expand Up @@ -2121,7 +2129,8 @@ func DecodeUntaggedBitArrayValue(b []byte) (remaining []byte, d bitarray.BitArra
}
words[i] = val
}
return b, bitarray.FromEncodingParts(words, lastBitsUsed), nil
ba, err := bitarray.FromEncodingParts(words, lastBitsUsed)
return b, ba, err
}

const uuidValueEncodedLength = 16
Expand Down

0 comments on commit eaf5808

Please sign in to comment.