Skip to content

Commit 916de2d

Browse files
committed
add zstdProposalDecompressor
1 parent b4f8360 commit 916de2d

File tree

2 files changed

+49
-21
lines changed

2 files changed

+49
-21
lines changed

network/msgCompressor.go

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,26 @@ const MaxDecompressedMessageSize = 20 * 1024 * 1024 // some large enough value
7676
// wsPeerMsgDataConverter performs optional incoming messages conversion.
7777
// At the moment it only supports zstd decompression for payload proposal
7878
type wsPeerMsgDataConverter struct {
79-
log logging.Logger
80-
origin string
81-
shouldDecompressProposalPayload bool
79+
log logging.Logger
80+
origin string
81+
82+
// actual converter(s)
83+
ppdec zstdProposalDecompressor
84+
}
85+
86+
type zstdProposalDecompressor struct {
87+
active bool
88+
}
89+
90+
func (dec zstdProposalDecompressor) enabled() bool {
91+
return dec.active
92+
}
93+
94+
func (dec zstdProposalDecompressor) accept(data []byte) bool {
95+
return len(data) > 4 && bytes.Equal(data[:4], zstdCompressionMagic[:])
8296
}
8397

84-
func (c *wsPeerMsgDataConverter) zstdDecompress(data []byte) ([]byte, error) {
98+
func (dec zstdProposalDecompressor) convert(data []byte) ([]byte, error) {
8599
r := zstd.NewReader(bytes.NewReader(data))
86100
defer r.Close()
87101
b := make([]byte, 0, 1024)
@@ -98,27 +112,40 @@ func (c *wsPeerMsgDataConverter) zstdDecompress(data []byte) ([]byte, error) {
98112
return nil, err
99113
}
100114
if len(b) > MaxDecompressedMessageSize {
101-
return nil, fmt.Errorf("proposal from peer %s data is too large: %d", c.origin, len(b))
115+
return nil, fmt.Errorf("proposal data is too large: %d", len(b))
102116
}
103117
}
104118
}
105119

106120
func (c *wsPeerMsgDataConverter) convert(tag protocol.Tag, data []byte) ([]byte, error) {
107-
if tag == protocol.ProposalPayloadTag && c.shouldDecompressProposalPayload {
108-
// sender might support compressed payload but fail to compress for whatever reason,
109-
// in this case it sends non-compressed payload - the receiver decompress only if it is compressed.
110-
if len(data) > 4 && bytes.Equal(data[:4], zstdCompressionMagic[:]) {
111-
return c.zstdDecompress(data)
121+
if tag == protocol.ProposalPayloadTag {
122+
if c.ppdec.enabled() {
123+
// sender might support compressed payload but fail to compress for whatever reason,
124+
// in this case it sends non-compressed payload - the receiver decompress only if it is compressed.
125+
if c.ppdec.accept(data) {
126+
res, err := c.ppdec.convert(data)
127+
if err != nil {
128+
return nil, fmt.Errorf("peer %s: %w", c.origin, err)
129+
}
130+
return res, nil
131+
}
132+
c.log.Warnf("peer %s supported zstd but sent non-compressed data", c.origin)
112133
}
113-
c.log.Warnf("peer %s supported zstd but sent non-compressed data", c.origin)
114134
}
115135
return data, nil
116136
}
117137

118138
func makeWsPeerMsgDataConverter(wp *wsPeer) *wsPeerMsgDataConverter {
119-
return &wsPeerMsgDataConverter{
120-
log: wp.net.log,
121-
origin: wp.originAddress,
122-
shouldDecompressProposalPayload: wp.vfCompressedProposalSupported(),
139+
c := wsPeerMsgDataConverter{
140+
log: wp.net.log,
141+
origin: wp.originAddress,
142+
}
143+
144+
if wp.vfCompressedProposalSupported() {
145+
c.ppdec = zstdProposalDecompressor{
146+
active: true,
147+
}
123148
}
149+
150+
return &c
124151
}

network/msgCompressor_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,16 @@ func TestZstdDecompress(t *testing.T) {
3434
msg := []byte(strings.Repeat("1", 2048))
3535
compressed, err := zstd.Compress(nil, msg)
3636
require.NoError(t, err)
37-
c := wsPeerMsgDataConverter{}
38-
decompressed, err := c.zstdDecompress(compressed)
37+
d := zstdProposalDecompressor{}
38+
decompressed, err := d.convert(compressed)
3939
require.NoError(t, err)
4040
require.Equal(t, msg, decompressed)
4141

4242
// error case - large message
4343
msg = []byte(strings.Repeat("1", MaxDecompressedMessageSize+10))
4444
compressed, err = zstd.Compress(nil, msg)
4545
require.NoError(t, err)
46-
decompressed, err = c.zstdDecompress(compressed)
46+
decompressed, err = d.convert(compressed)
4747
require.Error(t, err)
4848
require.Nil(t, decompressed)
4949
}
@@ -93,8 +93,8 @@ func TestZstdCompressMsg(t *testing.T) {
9393
require.Empty(t, msg)
9494
require.Equal(t, []byte(protocol.ProposalPayloadTag), comp[:ppt])
9595
require.Equal(t, zstdCompressionMagic[:], comp[ppt:ppt+len(zstdCompressionMagic)])
96-
c := wsPeerMsgDataConverter{}
97-
decompressed, err := c.zstdDecompress(comp[ppt:])
96+
d := zstdProposalDecompressor{}
97+
decompressed, err := d.convert(comp[ppt:])
9898
require.NoError(t, err)
9999
require.Equal(t, data, decompressed)
100100
}
@@ -113,6 +113,7 @@ func TestWsPeerMsgDataConverterConvert(t *testing.T) {
113113
partitiontest.PartitionTest(t)
114114

115115
c := wsPeerMsgDataConverter{}
116+
c.ppdec = zstdProposalDecompressor{active: false}
116117
tag := protocol.AgreementVoteTag
117118
data := []byte("data")
118119

@@ -127,7 +128,7 @@ func TestWsPeerMsgDataConverterConvert(t *testing.T) {
127128

128129
l := converterTestLogger{}
129130
c.log = &l
130-
c.shouldDecompressProposalPayload = true
131+
c.ppdec = zstdProposalDecompressor{active: true}
131132
r, err = c.convert(tag, data)
132133
require.NoError(t, err)
133134
require.Equal(t, data, r)

0 commit comments

Comments
 (0)