-
Notifications
You must be signed in to change notification settings - Fork 125
/
bitvec.go
122 lines (104 loc) · 2.58 KB
/
bitvec.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Copyright 2024 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only
package parachaintypes
import (
"bytes"
"fmt"
"io"
"github.com/ChainSafe/gossamer/pkg/scale"
)
const (
	// byteSize is the number of bits packed into one byte of the encoding.
	byteSize = 8
	// bitVecMaxLength is the largest number of bits (2^28 - 1) that
	// MarshalSCALE and UnmarshalSCALE accept.
	bitVecMaxLength = 268435455
)

// errBitVecTooLong is returned when a bit vector holds, or an encoded
// length announces, more than bitVecMaxLength bits.
var errBitVecTooLong = fmt.Errorf("bitvec too long")
// BitVec is the implementation of a bit vector.
// Each element of bits is one bit of the vector: true is encoded as 1 and
// false as 0. Element i is stored in bit (i % 8) of byte (i / 8) when the
// vector is packed for SCALE encoding.
type BitVec struct {
bits []bool
}
// NewBitVec wraps the given bits in a BitVec.
// The slice is stored as-is (not copied), so the returned BitVec shares its
// backing storage with the caller's slice.
//
// This isn't a complete implementation of the bit vector.
// It is only used for ParachainHost runtime exports.
// TODO: Implement the full bit vector
// https://github.com/ChainSafe/gossamer/issues/3248
func NewBitVec(bits []bool) BitVec {
	return BitVec{bits: bits}
}
// bytes packs the bits of the vector into a freshly allocated byte slice of
// ceil(len(bits)/8) bytes. Bit i is stored in bit (i % 8) of byte (i / 8),
// i.e. lsb ordering; unused high bits of the last byte are left at zero.
// The receiver's bits slice is only read, never appended to or modified.
// TODO: Implement msb ordering
// https://github.com/ChainSafe/gossamer/issues/3248
func (bv *BitVec) bytes() []byte {
	numOfBytes := (len(bv.bits) + (byteSize - 1)) / byteSize
	packed := make([]byte, numOfBytes)

	// No explicit padding of bv.bits is needed: packed is zero-initialised,
	// and appending padding to bv.bits could overwrite elements of a larger
	// slice sharing the same backing array.
	for i, set := range bv.bits {
		if set {
			packed[i/byteSize] |= 1 << (i % byteSize)
		}
	}
	return packed
}
// setBits replaces the vector's bits with up to size bits unpacked from b.
// Bits are taken from each byte starting at the least significant bit, eight
// per byte, until size bits have been produced or b is exhausted. If no bit
// is produced the stored slice is nil.
func (bv *BitVec) setBits(b []byte, size uint) {
	var decoded []bool
	remaining := size
	for _, octet := range b {
		// Take a full byte unless fewer bits are still wanted.
		take := uint(byteSize)
		if remaining < take {
			take = remaining
		}
		remaining -= take

		for shift := uint(0); shift < take; shift++ {
			decoded = append(decoded, octet&(1<<shift) != 0)
		}
	}
	bv.bits = decoded
}
// MarshalSCALE fulfils the SCALE interface for encoding.
//
// The output is the SCALE encoding of the bit count (as a uint) followed by
// the bits packed into bytes by bv.bytes(). It returns errBitVecTooLong when
// the vector holds more than bitVecMaxLength bits, and otherwise forwards
// any error reported by the encoder or the buffer.
func (bv BitVec) MarshalSCALE() ([]byte, error) {
	bitCount := len(bv.bits)
	if bitCount > bitVecMaxLength {
		return nil, errBitVecTooLong
	}

	out := bytes.NewBuffer(nil)
	// Length prefix first, then the packed payload.
	if err := scale.NewEncoder(out).Encode(uint(bitCount)); err != nil {
		return nil, err
	}
	if _, err := out.Write(bv.bytes()); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// UnmarshalSCALE fulfils the SCALE interface for decoding.
//
// It decodes the bit count from r, rejects counts above bitVecMaxLength with
// errBitVecTooLong, then reads exactly ceil(count/8) payload bytes from r and
// unpacks them into bv.
//
// Errors from the decoder or from r are returned unchanged. If the payload
// ends early the error is io.EOF (no payload byte was read) or
// io.ErrUnexpectedEOF (only part of the payload was read).
func (bv *BitVec) UnmarshalSCALE(r io.Reader) error {
	var size uint
	if err := scale.NewDecoder(r).Decode(&size); err != nil {
		return err
	}
	if size > bitVecMaxLength {
		return errBitVecTooLong
	}

	numBytes := (size + (byteSize - 1)) / byteSize
	b := make([]byte, numBytes)
	// io.ReadFull rather than a single r.Read: a Reader may legally return
	// fewer bytes than requested with a nil error, which would silently
	// decode missing payload as zero bits. ReadFull also performs no read at
	// all for an empty payload, so an empty vector at end of stream decodes
	// cleanly.
	if _, err := io.ReadFull(r, b); err != nil {
		return err
	}

	bv.setBits(b, size)
	return nil
}