diff --git a/go.mod b/go.mod index f20b25dc6d..96ddeb7044 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( ) require ( + github.com/DataDog/zstd v1.5.2 // indirect github.com/chrismcguire/gobberish v0.0.0-20150821175641-1d8adb509a0e // indirect github.com/cpuguy83/go-md2man v1.0.8 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 58fbbdb98c..548f9d7a0d 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= +github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/algorand/avm-abi v0.1.0 h1:znZFQXpSUVYz37vXbaH5OZG2VK4snTyXwnc/tV9CVr4= github.com/algorand/avm-abi v0.1.0/go.mod h1:+CgwM46dithy850bpTeHh9MC99zpn2Snirb3QTl2O/g= github.com/algorand/falcon v0.0.0-20220727072124-02a2a64c4414 h1:nwYN+GQ7Z5OOfZwqBO1ma7DSlP7S1YrKWICOyjkwqrc= diff --git a/network/msgCompressor.go b/network/msgCompressor.go new file mode 100644 index 0000000000..d6991163f1 --- /dev/null +++ b/network/msgCompressor.go @@ -0,0 +1,157 @@ +// Copyright (C) 2019-2022 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see . + +package network + +import ( + "bytes" + "fmt" + "io" + + "github.com/DataDog/zstd" + + "github.com/algorand/go-algorand/logging" + "github.com/algorand/go-algorand/protocol" +) + +var zstdCompressionMagic = [4]byte{0x28, 0xb5, 0x2f, 0xfd} + +const zstdCompressionLevel = zstd.BestSpeed + +// checkCanCompress checks if there is an proposal payload message and peers supporting compression +func checkCanCompress(request broadcastRequest, peers []*wsPeer) bool { + canCompress := false + hasPP := false + for _, tag := range request.tags { + if tag == protocol.ProposalPayloadTag { + hasPP = true + break + } + } + // if have proposal payload check if there are any peers supporting compression + if hasPP { + for _, peer := range peers { + if peer.pfProposalCompressionSupported() { + canCompress = true + break + } + } + } + return canCompress +} + +// zstdCompressMsg returns a concatenation of a tag and compressed data +func zstdCompressMsg(tbytes []byte, d []byte) ([]byte, string) { + bound := zstd.CompressBound(len(d)) + if bound < len(d) { + // although CompressBound allocated more than the src size, this is an implementation detail. + // increase the buffer size to always have enough space for the raw data if compression fails. 
+ bound = len(d) + } + mbytesComp := make([]byte, len(tbytes)+bound) + copy(mbytesComp, tbytes) + comp, err := zstd.CompressLevel(mbytesComp[len(tbytes):], d, zstdCompressionLevel) + if err != nil { + // fallback and reuse non-compressed original data + logMsg := fmt.Sprintf("failed to compress into buffer of len %d: %v", len(d), err) + copied := copy(mbytesComp[len(tbytes):], d) + return mbytesComp[:len(tbytes)+copied], logMsg + } + mbytesComp = mbytesComp[:len(tbytes)+len(comp)] + return mbytesComp, "" +} + +// MaxDecompressedMessageSize defines a maximum decompressed data size +// to prevent zip bombs +const MaxDecompressedMessageSize = 20 * 1024 * 1024 // some large enough value + +// wsPeerMsgDataConverter performs optional incoming messages conversion. +// At the moment it only supports zstd decompression for payload proposal +type wsPeerMsgDataConverter struct { + log logging.Logger + origin string + + // actual converter(s) + ppdec zstdProposalDecompressor +} + +type zstdProposalDecompressor struct { + active bool +} + +func (dec zstdProposalDecompressor) enabled() bool { + return dec.active +} + +func (dec zstdProposalDecompressor) accept(data []byte) bool { + return len(data) > 4 && bytes.Equal(data[:4], zstdCompressionMagic[:]) +} + +func (dec zstdProposalDecompressor) convert(data []byte) ([]byte, error) { + r := zstd.NewReader(bytes.NewReader(data)) + defer r.Close() + b := make([]byte, 0, 3*len(data)) + for { + if len(b) == cap(b) { + // grow capacity, retain length + b = append(b, 0)[:len(b)] + } + n, err := r.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if err == io.EOF { + return b, nil + } + return nil, err + } + if len(b) > MaxDecompressedMessageSize { + return nil, fmt.Errorf("proposal data is too large: %d", len(b)) + } + } +} + +func (c *wsPeerMsgDataConverter) convert(tag protocol.Tag, data []byte) ([]byte, error) { + if tag == protocol.ProposalPayloadTag { + if c.ppdec.enabled() { + // sender might support compressed payload but fail to compress for whatever reason, + // in this case it sends non-compressed payload - the receiver decompress only if it is compressed. + if c.ppdec.accept(data) { + res, err := c.ppdec.convert(data) + if err != nil { + return nil, fmt.Errorf("peer %s: %w", c.origin, err) + } + return res, nil + } + c.log.Warnf("peer %s supported zstd but sent non-compressed data", c.origin) + } + } + return data, nil +} + +func makeWsPeerMsgDataConverter(wp *wsPeer) *wsPeerMsgDataConverter { + c := wsPeerMsgDataConverter{ + log: wp.net.log, + origin: wp.originAddress, + } + + if wp.pfProposalCompressionSupported() { + c.ppdec = zstdProposalDecompressor{ + active: true, + } + } + + return &c +} diff --git a/network/msgCompressor_test.go b/network/msgCompressor_test.go new file mode 100644 index 0000000000..0a8713c870 --- /dev/null +++ b/network/msgCompressor_test.go @@ -0,0 +1,142 @@ +// Copyright (C) 2019-2022 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. 
+// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see . + +package network + +import ( + "strings" + "testing" + + "github.com/DataDog/zstd" + "github.com/algorand/go-algorand/logging" + "github.com/algorand/go-algorand/protocol" + "github.com/algorand/go-algorand/test/partitiontest" + "github.com/stretchr/testify/require" +) + +func TestZstdDecompress(t *testing.T) { + partitiontest.PartitionTest(t) + + // happy case - small message + msg := []byte(strings.Repeat("1", 2048)) + compressed, err := zstd.Compress(nil, msg) + require.NoError(t, err) + d := zstdProposalDecompressor{} + decompressed, err := d.convert(compressed) + require.NoError(t, err) + require.Equal(t, msg, decompressed) + + // error case - large message + msg = []byte(strings.Repeat("1", MaxDecompressedMessageSize+10)) + compressed, err = zstd.Compress(nil, msg) + require.NoError(t, err) + decompressed, err = d.convert(compressed) + require.Error(t, err) + require.Nil(t, decompressed) +} + +func TestCheckCanCompress(t *testing.T) { + partitiontest.PartitionTest(t) + + req := broadcastRequest{} + peers := []*wsPeer{} + r := checkCanCompress(req, peers) + require.False(t, r) + + req.tags = []protocol.Tag{protocol.AgreementVoteTag} + r = checkCanCompress(req, peers) + require.False(t, r) + + req.tags = []protocol.Tag{protocol.AgreementVoteTag, protocol.ProposalPayloadTag} + r = checkCanCompress(req, peers) + require.False(t, r) + + peer1 := wsPeer{ + features: 0, + } + peers = []*wsPeer{&peer1} + r = checkCanCompress(req, peers) + require.False(t, r) + + peer2 := wsPeer{ + features: pfCompressedProposal, + } + peers = []*wsPeer{&peer1, &peer2} + r = checkCanCompress(req, peers) + require.True(t, r) +} + +func TestZstdCompressMsg(t *testing.T) { + partitiontest.PartitionTest(t) + + ppt := len(protocol.ProposalPayloadTag) + data := []byte("data") + comp, msg := zstdCompressMsg([]byte(protocol.ProposalPayloadTag), data) + require.Empty(t, msg) + require.Equal(t, []byte(protocol.ProposalPayloadTag), comp[:ppt]) + require.Equal(t, zstdCompressionMagic[:], comp[ppt:ppt+len(zstdCompressionMagic)]) + d := zstdProposalDecompressor{} + decompressed, err := d.convert(comp[ppt:]) + require.NoError(t, err) + require.Equal(t, data, decompressed) +} + +type converterTestLogger struct { + logging.Logger + WarnfCallback func(string, ...interface{}) + warnMsgCount int +} + +func (cl *converterTestLogger) Warnf(s string, args ...interface{}) { + cl.warnMsgCount++ +} + +func TestWsPeerMsgDataConverterConvert(t *testing.T) { + partitiontest.PartitionTest(t) + + c := wsPeerMsgDataConverter{} + c.ppdec = zstdProposalDecompressor{active: false} + tag := protocol.AgreementVoteTag + data := []byte("data") + + r, err := c.convert(tag, data) + require.NoError(t, err) + require.Equal(t, data, r) + + tag = protocol.ProposalPayloadTag + r, err = c.convert(tag, data) + require.NoError(t, err) + require.Equal(t, data, r) + + l := converterTestLogger{} + c.log = &l + c.ppdec = zstdProposalDecompressor{active: true} + r, err = c.convert(tag, data) + require.NoError(t, err) + require.Equal(t, data, r) + require.Equal(t, 1, l.warnMsgCount) + + l = converterTestLogger{} + c.log = &l + + comp, err := zstd.Compress(nil, data) + require.NoError(t, err) + + r, err = c.convert(tag, comp) + require.NoError(t, err) + require.Equal(t, data, r) + require.Equal(t, 0, l.warnMsgCount) +} diff --git a/network/wsNetwork.go b/network/wsNetwork.go index 096dedf7b2..812d4db221 100644 --- a/network/wsNetwork.go +++ 
b/network/wsNetwork.go @@ -126,10 +126,15 @@ var peers = metrics.MakeGauge(metrics.MetricName{Name: "algod_network_peers", De var incomingPeers = metrics.MakeGauge(metrics.MetricName{Name: "algod_network_incoming_peers", Description: "Number of active incoming peers."}) var outgoingPeers = metrics.MakeGauge(metrics.MetricName{Name: "algod_network_outgoing_peers", Description: "Number of active outgoing peers."}) -// peerDisconnectionAckDuration defines the time we would wait for the peer disconnection to compelete. +var networkPrioBatchesPPWithCompression = metrics.MakeCounter(metrics.MetricName{Name: "algod_network_prio_batches_wpp_comp_sent_total", Description: "number of prio compressed batches with PP"}) +var networkPrioBatchesPPWithoutCompression = metrics.MakeCounter(metrics.MetricName{Name: "algod_network_pp_prio_batches_wpp_non_comp_sent_total", Description: "number of prio non-compressed batches with PP"}) +var networkPrioPPCompressedSize = metrics.MakeCounter(metrics.MetricName{Name: "algod_network_prio_pp_compressed_size_total", Description: "cumulative size of all compressed PP"}) +var networkPrioPPNonCompressedSize = metrics.MakeCounter(metrics.MetricName{Name: "algod_network_prio_pp_non_compressed_size_total", Description: "cumulative size of all non-compressed PP"}) + +// peerDisconnectionAckDuration defines the time we would wait for the peer disconnection to complete. const peerDisconnectionAckDuration = 5 * time.Second -// peerShutdownDisconnectionAckDuration defines the time we would wait for the peer disconnection to compelete during shutdown. +// peerShutdownDisconnectionAckDuration defines the time we would wait for the peer disconnection to complete during shutdown. const peerShutdownDisconnectionAckDuration = 50 * time.Millisecond // Peer opaque interface for referring to a neighbor in the network @@ -386,7 +391,7 @@ type WebsocketNetwork struct { // connPerfMonitor is used on outgoing connections to measure their relative message timing connPerfMonitor *connectionPerformanceMonitor - // lastNetworkAdvanceMu syncronized the access to lastNetworkAdvance + // lastNetworkAdvanceMu synchronized the access to lastNetworkAdvance lastNetworkAdvanceMu deadlock.Mutex // lastNetworkAdvance contains the last timestamp where the agreement protocol was able to make a notable progress. @@ -430,6 +435,13 @@ type WebsocketNetwork struct { // atomic {0:unknown, 1:yes, 2:no} wantTXGossip uint32 + + // supportedProtocolVersions defines versions supported by this network. 
+ // Should be used instead of a global network.SupportedProtocolVersions for network/peers configuration + supportedProtocolVersions []string + + // protocolVersion is an actual version announced as ProtocolVersionHeader + protocolVersion string } const ( @@ -760,10 +772,16 @@ func (wn *WebsocketNetwork) setup() { wn.lastNetworkAdvance = time.Now().UTC() wn.handlers.log = wn.log + // set our supported versions if wn.config.NetworkProtocolVersion != "" { - SupportedProtocolVersions = []string{wn.config.NetworkProtocolVersion} + wn.supportedProtocolVersions = []string{wn.config.NetworkProtocolVersion} + } else { + wn.supportedProtocolVersions = SupportedProtocolVersions } + // set our actual version + wn.protocolVersion = ProtocolVersion + wn.messagesOfInterestRefresh = make(chan struct{}, 2) wn.messagesOfInterestGeneration = 1 // something nonzero so that any new wsPeer needs updating if wn.relayMessages { @@ -930,7 +948,7 @@ func (wn *WebsocketNetwork) setHeaders(header http.Header) { func (wn *WebsocketNetwork) checkServerResponseVariables(otherHeader http.Header, addr string) (bool, string) { matchingVersion, otherVersion := wn.checkProtocolVersionMatch(otherHeader) if matchingVersion == "" { - wn.log.Info(filterASCII(fmt.Sprintf("new peer %s version mismatch, mine=%v theirs=%s, headers %#v", addr, SupportedProtocolVersions, otherVersion, otherHeader))) + wn.log.Info(filterASCII(fmt.Sprintf("new peer %s version mismatch, mine=%v theirs=%s, headers %#v", addr, wn.supportedProtocolVersions, otherVersion, otherHeader))) return false, "" } otherRandom := otherHeader.Get(NodeRandomHeader) @@ -1003,7 +1021,7 @@ func (wn *WebsocketNetwork) checkProtocolVersionMatch(otherHeaders http.Header) otherAcceptedVersions := otherHeaders[textproto.CanonicalMIMEHeaderKey(ProtocolAcceptVersionHeader)] for _, otherAcceptedVersion := range otherAcceptedVersions { // do we have a matching version ? 
- for _, supportedProtocolVersion := range SupportedProtocolVersions { + for _, supportedProtocolVersion := range wn.supportedProtocolVersions { if supportedProtocolVersion == otherAcceptedVersion { matchingVersion = supportedProtocolVersion return matchingVersion, "" @@ -1012,7 +1030,7 @@ func (wn *WebsocketNetwork) checkProtocolVersionMatch(otherHeaders http.Header) } otherVersion = otherHeaders.Get(ProtocolVersionHeader) - for _, supportedProtocolVersion := range SupportedProtocolVersions { + for _, supportedProtocolVersion := range wn.supportedProtocolVersions { if supportedProtocolVersion == otherVersion { return supportedProtocolVersion, otherVersion } @@ -1097,10 +1115,10 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt matchingVersion, otherVersion := wn.checkProtocolVersionMatch(request.Header) if matchingVersion == "" { - wn.log.Info(filterASCII(fmt.Sprintf("new peer %s version mismatch, mine=%v theirs=%s, headers %#v", request.RemoteAddr, SupportedProtocolVersions, otherVersion, request.Header))) + wn.log.Info(filterASCII(fmt.Sprintf("new peer %s version mismatch, mine=%v theirs=%s, headers %#v", request.RemoteAddr, wn.supportedProtocolVersions, otherVersion, request.Header))) networkConnectionsDroppedTotal.Inc(map[string]string{"reason": "mismatching protocol version"}) response.WriteHeader(http.StatusPreconditionFailed) - message := fmt.Sprintf("Requested version %s not in %v mismatches server version", filterASCII(otherVersion), SupportedProtocolVersions) + message := fmt.Sprintf("Requested version %s not in %v mismatches server version", filterASCII(otherVersion), wn.supportedProtocolVersions) n, err := response.Write([]byte(message)) if err != nil { wn.log.Warnf("ws failed to write response '%s' : n = %d err = %v", message, n, err) @@ -1120,6 +1138,7 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt wn.setHeaders(responseHeader) responseHeader.Set(ProtocolVersionHeader, matchingVersion) responseHeader.Set(GenesisHeader, wn.GenesisID) + responseHeader.Set(PeerFeaturesHeader, PeerFeatureProposalCompression) var challenge string if wn.prioScheme != nil { challenge = wn.prioScheme.NewPrioChallenge() @@ -1146,6 +1165,7 @@ func (wn *WebsocketNetwork) ServeHTTP(response http.ResponseWriter, request *htt prioChallenge: challenge, createTime: trackedRequest.created, version: matchingVersion, + features: decodePeerFeatures(matchingVersion, request.Header.Get(PeerFeaturesHeader)), } peer.TelemetryGUID = trackedRequest.otherTelemetryGUID peer.init(wn.config, wn.outgoingMessagesBufferSize) @@ -1418,23 +1438,21 @@ func (wn *WebsocketNetwork) peerSnapshot(dest []*wsPeer) ([]*wsPeer, int32) { return dest, peerChangeCounter } -// prio is set if the broadcast is a high-priority broadcast. -func (wn *WebsocketNetwork) innerBroadcast(request broadcastRequest, prio bool, peers []*wsPeer) { - if request.done != nil { - defer close(request.done) +// preparePeerData prepares batches of data for sending. 
+// It performs optional zstd compression for proposal messages +func (wn *WebsocketNetwork) preparePeerData(request broadcastRequest, prio bool, peers []*wsPeer) ([][]byte, [][]byte, []crypto.Digest) { + // determine if there is a proposal payload and peers supporting compressed payloads + wantCompression := false + if prio { + wantCompression = checkCanCompress(request, peers) } - broadcastQueueDuration := time.Now().Sub(request.enqueueTime) - networkBroadcastQueueMicros.AddUint64(uint64(broadcastQueueDuration.Nanoseconds()/1000), nil) - if broadcastQueueDuration > maxMessageQueueDuration { - networkBroadcastsDropped.Inc(nil) - return + digests := make([]crypto.Digest, len(request.data)) + data := make([][]byte, len(request.data)) + var dataCompressed [][]byte + if wantCompression { + dataCompressed = make([][]byte, len(request.data)) } - - start := time.Now() - - digests := make([]crypto.Digest, len(request.data), len(request.data)) - data := make([][]byte, len(request.data), len(request.data)) for i, d := range request.data { tbytes := []byte(request.tags[i]) mbytes := make([]byte, len(tbytes)+len(d)) @@ -1444,8 +1462,45 @@ func (wn *WebsocketNetwork) innerBroadcast(request broadcastRequest, prio bool, if request.tags[i] != protocol.MsgDigestSkipTag && len(d) >= messageFilterSize { digests[i] = crypto.Hash(mbytes) } + + if prio && request.tags[i] == protocol.ProposalPayloadTag { + networkPrioPPNonCompressedSize.Add(float64(len(d)), nil) + } + + if wantCompression { + if request.tags[i] == protocol.ProposalPayloadTag { + compressed, logMsg := zstdCompressMsg(tbytes, d) + if len(logMsg) > 0 { + wn.log.Warn(logMsg) + } else { + networkPrioPPCompressedSize.Add(float64(len(compressed)), nil) + } + dataCompressed[i] = compressed + } else { + // otherwise reuse non-compressed from above + dataCompressed[i] = mbytes + } + } + } + return data, dataCompressed, digests +} + +// prio is set if the broadcast is a high-priority broadcast. +func (wn *WebsocketNetwork) innerBroadcast(request broadcastRequest, prio bool, peers []*wsPeer) { + if request.done != nil { + defer close(request.done) + } + + broadcastQueueDuration := time.Since(request.enqueueTime) + networkBroadcastQueueMicros.AddUint64(uint64(broadcastQueueDuration.Nanoseconds()/1000), nil) + if broadcastQueueDuration > maxMessageQueueDuration { + networkBroadcastsDropped.Inc(nil) + return } + start := time.Now() + data, dataWithCompression, digests := wn.preparePeerData(request, prio, peers) + // first send to all the easy outbound peers who don't block, get them started.
sentMessageCount := 0 for _, peer := range peers { @@ -1455,7 +1510,19 @@ func (wn *WebsocketNetwork) innerBroadcast(request broadcastRequest, prio bool, if peer == request.except { continue } - ok := peer.writeNonBlockMsgs(request.ctx, data, prio, digests, request.enqueueTime) + var ok bool + if peer.pfProposalCompressionSupported() && len(dataWithCompression) > 0 { + // if this peer supports compressed proposals and compressed data batch is filled out, use it + ok = peer.writeNonBlockMsgs(request.ctx, dataWithCompression, prio, digests, request.enqueueTime) + if prio { + networkPrioBatchesPPWithCompression.Inc(nil) + } + } else { + ok = peer.writeNonBlockMsgs(request.ctx, data, prio, digests, request.enqueueTime) + if prio { + networkPrioBatchesPPWithoutCompression.Inc(nil) + } + } if ok { sentMessageCount++ continue @@ -1463,7 +1530,7 @@ func (wn *WebsocketNetwork) innerBroadcast(request broadcastRequest, prio bool, networkPeerBroadcastDropped.Inc(nil) } - dt := time.Now().Sub(start) + dt := time.Since(start) networkBroadcasts.Inc(nil) networkBroadcastSendMicros.AddUint64(uint64(dt.Nanoseconds()/1000), nil) } @@ -1838,14 +1905,15 @@ const ProtocolVersionHeader = "X-Algorand-Version" const ProtocolAcceptVersionHeader = "X-Algorand-Accept-Version" // SupportedProtocolVersions contains the list of supported protocol versions by this node ( in order of preference ). -var SupportedProtocolVersions = []string{"2.1"} +var SupportedProtocolVersions = []string{"2.2", "2.1"} // ProtocolVersion is the current version attached to the ProtocolVersionHeader header /* Version history: * 1 Catchup service over websocket connections with unicast messages between peers * 2.1 Introduced topic key/data pairs and enabled services over the gossip connections + * 2.2 Peer features */ -const ProtocolVersion = "2.1" +const ProtocolVersion = "2.2" // TelemetryIDHeader HTTP header for telemetry-id for logging const TelemetryIDHeader = "X-Algorand-TelId" @@ -1871,6 +1939,13 @@ const TooManyRequestsRetryAfterHeader = "Retry-After" // UserAgentHeader is the HTTP header identify the user agent. const UserAgentHeader = "User-Agent" +// PeerFeaturesHeader is the HTTP header listing features +const PeerFeaturesHeader = "X-Algorand-Peer-Features" + +// PeerFeatureProposalCompression is a value for PeerFeaturesHeader indicating peer +// supports proposal payload compression with zstd +const PeerFeatureProposalCompression = "ppzstd" + var websocketsScheme = map[string]string{"http": "ws", "https": "wss"} var errBadAddr = errors.New("bad address") @@ -2011,11 +2086,13 @@ func (wn *WebsocketNetwork) tryConnect(addr, gossipAddr string) { defer wn.wg.Done() requestHeader := make(http.Header) wn.setHeaders(requestHeader) - for _, supportedProtocolVersion := range SupportedProtocolVersions { + for _, supportedProtocolVersion := range wn.supportedProtocolVersions { requestHeader.Add(ProtocolAcceptVersionHeader, supportedProtocolVersion) } // for backward compatibility, include the ProtocolVersion header as well. 
- requestHeader.Set(ProtocolVersionHeader, ProtocolVersion) + requestHeader.Set(ProtocolVersionHeader, wn.protocolVersion) + // set the features header (comma-separated list) + requestHeader.Set(PeerFeaturesHeader, PeerFeatureProposalCompression) SetUserAgentHeader(requestHeader) myInstanceName := wn.log.GetInstanceName() requestHeader.Set(InstanceNameHeader, myInstanceName) @@ -2030,7 +2107,7 @@ func (wn *WebsocketNetwork) tryConnect(addr, gossipAddr string) { conn, response, err := websocketDialer.DialContext(wn.ctx, gossipAddr, requestHeader) if err != nil { if err == websocket.ErrBadHandshake { - // reading here from ioutil is safe only because it came from DialContext above, which alredy finsihed reading all the data from the network + // reading here from ioutil is safe only because it came from DialContext above, which already finished reading all the data from the network // and placed it all in a ioutil.NopCloser reader. bodyBytes, _ := io.ReadAll(response.Body) errString := string(bodyBytes) @@ -2087,6 +2164,7 @@ func (wn *WebsocketNetwork) tryConnect(addr, gossipAddr string) { connMonitor: wn.connPerfMonitor, throttledOutgoingConnection: throttledConnection, version: matchingVersion, + features: decodePeerFeatures(matchingVersion, response.Header.Get(PeerFeaturesHeader)), } peer.TelemetryGUID, peer.InstanceName, _ = getCommonHeaders(response.Header) peer.init(wn.config, wn.outgoingMessagesBufferSize) diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index dc86511254..37dd646aa3 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -17,6 +17,7 @@ package network import ( + "bytes" "context" "encoding/binary" "fmt" @@ -200,6 +201,47 @@ func newMessageCounter(t testing.TB, target int) *messageCounterHandler { return &messageCounterHandler{target: target, done: make(chan struct{}), t: t} } +type messageMatcherHandler struct { + lock deadlock.Mutex + + target [][]byte + received [][]byte + done chan struct{} +} + +func (mmh *messageMatcherHandler) Handle(message IncomingMessage) OutgoingMessage { + mmh.lock.Lock() + defer mmh.lock.Unlock() + + mmh.received = append(mmh.received, message.Data) + if len(mmh.target) > 0 && mmh.done != nil && len(mmh.received) >= len(mmh.target) { + close(mmh.done) + mmh.done = nil + } + + return OutgoingMessage{Action: Ignore} +} + +func (mmh *messageMatcherHandler) Match() bool { + if len(mmh.target) != len(mmh.received) { + return false + } + + sort.Slice(mmh.target, func(i, j int) bool { return bytes.Compare(mmh.target[i], mmh.target[j]) == -1 }) + sort.Slice(mmh.received, func(i, j int) bool { return bytes.Compare(mmh.received[i], mmh.received[j]) == -1 }) + + for i := 0; i < len(mmh.target); i++ { + if !bytes.Equal(mmh.target[i], mmh.received[i]) { + return false + } + } + return true +} + +func newMessageMatcher(t testing.TB, target [][]byte) *messageMatcherHandler { + return &messageMatcherHandler{target: target, done: make(chan struct{})} +} + func TestWebsocketNetworkStartStop(t *testing.T) { partitiontest.PartitionTest(t) @@ -262,6 +304,82 @@ func TestWebsocketNetworkBasic(t *testing.T) { } } +// Set up two nodes, send proposal +func TestWebsocketProposalPayloadCompression(t *testing.T) { + partitiontest.PartitionTest(t) + + type testDef struct { + netASupProto []string + netAProto string + netBSupProto []string + netBProto string + } + + var tests []testDef = []testDef{ + // two old nodes + {[]string{"2.1"}, "2.1", []string{"2.1"}, "2.1"}, + + // two new nodes with overwritten config + {[]string{"2.2"}, 
"2.2", []string{"2.2"}, "2.2"}, + + // old node + new node + {[]string{"2.1"}, "2.1", []string{"2.2", "2.1"}, "2.2"}, + {[]string{"2.2", "2.1"}, "2.2", []string{"2.1"}, "2.1"}, + + // combinations + {[]string{"2.2", "2.1"}, "2.1", []string{"2.2", "2.1"}, "2.1"}, + {[]string{"2.2", "2.1"}, "2.2", []string{"2.2", "2.1"}, "2.1"}, + {[]string{"2.2", "2.1"}, "2.1", []string{"2.2", "2.1"}, "2.2"}, + {[]string{"2.2", "2.1"}, "2.2", []string{"2.2", "2.1"}, "2.2"}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("A_%s(%s)+B_%s(%s)", test.netASupProto, test.netAProto, test.netBSupProto, test.netBProto), func(t *testing.T) { + netA := makeTestWebsocketNode(t) + netA.config.GossipFanout = 1 + netA.protocolVersion = test.netAProto + netA.supportedProtocolVersions = test.netASupProto + netA.Start() + defer netStop(t, netA, "A") + netB := makeTestWebsocketNode(t) + netB.config.GossipFanout = 1 + netB.protocolVersion = test.netBProto + netA.supportedProtocolVersions = test.netBSupProto + addrA, postListen := netA.Address() + require.True(t, postListen) + t.Log(addrA) + netB.phonebook.ReplacePeerList([]string{addrA}, "default", PhoneBookEntryRelayRole) + netB.Start() + defer netStop(t, netB, "B") + messages := [][]byte{ + []byte("foo"), + []byte("bar"), + } + matcher := newMessageMatcher(t, messages) + counterDone := matcher.done + netB.RegisterHandlers([]TaggedMessageHandler{{Tag: protocol.ProposalPayloadTag, MessageHandler: matcher}}) + + readyTimeout := time.NewTimer(2 * time.Second) + waitReady(t, netA, readyTimeout.C) + t.Log("a ready") + waitReady(t, netB, readyTimeout.C) + t.Log("b ready") + + for _, msg := range messages { + netA.Broadcast(context.Background(), protocol.ProposalPayloadTag, msg, false, nil) + } + + select { + case <-counterDone: + case <-time.After(2 * time.Second): + t.Errorf("timeout, count=%d, wanted %d", len(matcher.received), len(messages)) + } + + require.True(t, matcher.Match()) + }) + } +} + // Repeat basic, but test a unicast func TestWebsocketNetworkUnicast(t *testing.T) { partitiontest.PartitionTest(t) @@ -1594,11 +1712,6 @@ func TestSetUserAgentHeader(t *testing.T) { func TestCheckProtocolVersionMatch(t *testing.T) { partitiontest.PartitionTest(t) - // note - this test changes the SupportedProtocolVersions global variable ( SupportedProtocolVersions ) and therefore cannot be parallelized. 
- originalSupportedProtocolVersions := SupportedProtocolVersions - defer func() { - SupportedProtocolVersions = originalSupportedProtocolVersions - }() log := logging.TestingLog(t) log.SetLevel(logging.Level(defaultConfig.BaseLoggerDebugLevel)) wn := &WebsocketNetwork{ @@ -1609,8 +1722,7 @@ func TestCheckProtocolVersionMatch(t *testing.T) { NetworkID: config.Devtestnet, } wn.setup() - - SupportedProtocolVersions = []string{"2", "1"} + wn.supportedProtocolVersions = []string{"2", "1"} header1 := make(http.Header) header1.Add(ProtocolAcceptVersionHeader, "1") @@ -2574,3 +2686,56 @@ func TestParseHostOrURL(t *testing.T) { }) } } + +func TestPreparePeerData(t *testing.T) { + partitiontest.PartitionTest(t) + + // no comression + req := broadcastRequest{ + tags: []protocol.Tag{protocol.AgreementVoteTag, protocol.ProposalPayloadTag}, + data: [][]byte{[]byte("test"), []byte("data")}, + } + + peers := []*wsPeer{} + wn := WebsocketNetwork{} + data, comp, digests := wn.preparePeerData(req, false, peers) + require.NotEmpty(t, data) + require.Empty(t, comp) + require.NotEmpty(t, digests) + require.Equal(t, len(req.data), len(digests)) + require.Equal(t, len(data), len(digests)) + + for i := range data { + require.Equal(t, append([]byte(req.tags[i]), req.data[i]...), data[i]) + } + + // compression + peer1 := wsPeer{ + features: 0, + } + peer2 := wsPeer{ + features: pfCompressedProposal, + } + peers = []*wsPeer{&peer1, &peer2} + data, comp, digests = wn.preparePeerData(req, true, peers) + require.NotEmpty(t, data) + require.NotEmpty(t, comp) + require.NotEmpty(t, digests) + require.Equal(t, len(req.data), len(digests)) + require.Equal(t, len(data), len(digests)) + require.Equal(t, len(comp), len(digests)) + + for i := range data { + require.Equal(t, append([]byte(req.tags[i]), req.data[i]...), data[i]) + } + + for i := range comp { + if req.tags[i] != protocol.ProposalPayloadTag { + require.Equal(t, append([]byte(req.tags[i]), req.data[i]...), comp[i]) + require.Equal(t, data[i], comp[i]) + } else { + require.NotEqual(t, data[i], comp[i]) + require.Equal(t, append([]byte(req.tags[i]), zstdCompressionMagic[:]...), comp[i][:len(req.tags[i])+len(zstdCompressionMagic)]) + } + } +} diff --git a/network/wsPeer.go b/network/wsPeer.go index 870eefddbd..b163455767 100644 --- a/network/wsPeer.go +++ b/network/wsPeer.go @@ -24,6 +24,8 @@ import ( "net" "net/http" "runtime" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -58,6 +60,22 @@ func init() { networkReceivedBytesByTag = metrics.NewTagCounterFiltered("algod_network_received_bytes_{TAG}", "Number of bytes that were received from the network for {TAG} messages", tagStringList, "UNK") networkMessageReceivedByTag = metrics.NewTagCounterFiltered("algod_network_message_received_{TAG}", "Number of complete messages that were received from the network for {TAG} messages", tagStringList, "UNK") networkMessageSentByTag = metrics.NewTagCounterFiltered("algod_network_message_sent_{TAG}", "Number of complete messages that were sent to the network for {TAG} messages", tagStringList, "UNK") + + matched := false + for _, version := range SupportedProtocolVersions { + if version == versionPeerFeatures { + matched = true + } + } + if !matched { + panic(fmt.Sprintf("peer features version %s is not supported %v", versionPeerFeatures, SupportedProtocolVersions)) + } + + var err error + versionPeerFeaturesNum[0], versionPeerFeaturesNum[1], err = versionToMajorMinor(versionPeerFeatures) + if err != nil { + panic(fmt.Sprintf("failed to parse version %v: %s", 
versionPeerFeatures, err.Error())) + } } var networkSentBytesTotal = metrics.MakeCounter(metrics.NetworkSentBytesTotal) @@ -213,6 +231,9 @@ type wsPeer struct { // peer version ( this is one of the version supported by the current node and listed in SupportedProtocolVersions ) version string + // peer features derived from the peer version + features peerFeatureFlag + // responseChannels used by the client to wait on the response of the request responseChannels map[uint64]chan *Response @@ -220,10 +241,10 @@ type wsPeer struct { responseChannelsMutex deadlock.RWMutex // sendMessageTag is a map of allowed message to send to a peer. We don't use any synchronization on this map, and the - // only gurentee is that it's being accessed only during startup and/or by the sending loop go routine. + // only guarantee is that it's being accessed only during startup and/or by the sending loop go routine. sendMessageTag map[protocol.Tag]bool - // messagesOfInterestGeneration is this node's messagesOfInterest version that we have seent to this peer. + // messagesOfInterestGeneration is this node's messagesOfInterest version that we have seen to this peer. messagesOfInterestGeneration uint32 // connMonitor used to measure the relative performance of the connection @@ -231,10 +252,10 @@ type wsPeer struct { // field set to nil. connMonitor *connectionPerformanceMonitor - // peerMessageDelay is calculated by the connection monitor; it's the relative avarage per-message delay. + // peerMessageDelay is calculated by the connection monitor; it's the relative average per-message delay. peerMessageDelay int64 - // throttledOutgoingConnection determines if this outgoing connection will be throttled bassed on it's + // throttledOutgoingConnection determines if this outgoing connection will be throttled based on it's // performance or not. Throttled connections are more likely to be short-lived connections. 
throttledOutgoingConnection bool @@ -405,6 +426,8 @@ func (wp *wsPeer) readLoop() { }() wp.conn.SetReadLimit(maxMessageLength) slurper := MakeLimitedReaderSlurper(averageMessageLength, maxMessageLength) + dataConverter := makeWsPeerMsgDataConverter(wp) + for { msg := IncomingMessage{} mtype, reader, err := wp.conn.NextReader() @@ -444,6 +467,11 @@ func (wp *wsPeer) readLoop() { msg.processing = wp.processed msg.Received = time.Now().UnixNano() msg.Data = slurper.Bytes() + msg.Data, err = dataConverter.convert(msg.Tag, msg.Data) + if err != nil { + wp.reportReadErr(err) + return + } msg.Net = wp.net atomic.StoreInt64(&wp.lastPacketTime, msg.Received) networkReceivedBytesTotal.AddUint64(uint64(len(msg.Data)+2), nil) @@ -913,3 +941,57 @@ func (wp *wsPeer) sendMessagesOfInterest(messagesOfInterestGeneration uint32, me atomic.StoreUint32(&wp.messagesOfInterestGeneration, messagesOfInterestGeneration) } } + +func (wp *wsPeer) pfProposalCompressionSupported() bool { + return wp.features&pfCompressedProposal != 0 +} + +type peerFeatureFlag int + +const pfCompressedProposal peerFeatureFlag = 1 + +// versionPeerFeatures defines protocol version when peer features were introduced +const versionPeerFeatures = "2.2" + +// versionPeerFeaturesNum is a parsed numeric representation of versionPeerFeatures +var versionPeerFeaturesNum [2]int64 + +func versionToMajorMinor(version string) (int64, int64, error) { + parts := strings.Split(version, ".") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("version %s does not have two components", version) + } + major, err := strconv.ParseInt(parts[0], 10, 8) + if err != nil { + return 0, 0, err + } + minor, err := strconv.ParseInt(parts[1], 10, 8) + if err != nil { + return 0, 0, err + } + return major, minor, nil +} + +func decodePeerFeatures(version string, announcedFeatures string) peerFeatureFlag { + major, minor, err := versionToMajorMinor(version) + if err != nil { + return 0 + } + + if major < versionPeerFeaturesNum[0] { + return 0 + } + if minor < versionPeerFeaturesNum[1] { + return 0 + } + + var features peerFeatureFlag + parts := strings.Split(announcedFeatures, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == PeerFeatureProposalCompression { + features |= pfCompressedProposal + } + } + return features +} diff --git a/network/wsPeer_test.go b/network/wsPeer_test.go index 800ab5b148..d61c182d32 100644 --- a/network/wsPeer_test.go +++ b/network/wsPeer_test.go @@ -18,6 +18,7 @@ package network import ( "encoding/binary" + "fmt" "strings" "testing" "time" @@ -125,3 +126,56 @@ func TestTagCounterFiltering(t *testing.T) { }) } } + +func TestVersionToMajorMinor(t *testing.T) { + partitiontest.PartitionTest(t) + + ma, mi, err := versionToMajorMinor("1.2") + require.NoError(t, err) + require.Equal(t, int64(1), ma) + require.Equal(t, int64(2), mi) + + ma, mi, err = versionToMajorMinor("1.2.3") + require.Error(t, err) + require.Zero(t, ma) + require.Zero(t, mi) + + ma, mi, err = versionToMajorMinor("1") + require.Error(t, err) + require.Zero(t, ma) + require.Zero(t, mi) + + ma, mi, err = versionToMajorMinor("a.b") + require.Error(t, err) + require.Zero(t, ma) + require.Zero(t, mi) +} + +func TestVersionToFeature(t *testing.T) { + partitiontest.PartitionTest(t) + + tests := []struct { + ver string + hdr string + expected peerFeatureFlag + }{ + {"1.2", "", peerFeatureFlag(0)}, + {"1.2.3", "", peerFeatureFlag(0)}, + {"a.b", "", peerFeatureFlag(0)}, + {"2.1", "", peerFeatureFlag(0)}, + {"2.1", PeerFeatureProposalCompression, 
peerFeatureFlag(0)}, + {"2.2", "", peerFeatureFlag(0)}, + {"2.2", "test", peerFeatureFlag(0)}, + {"2.2", strings.Join([]string{"a", "b"}, ","), peerFeatureFlag(0)}, + {"2.2", PeerFeatureProposalCompression, pfCompressedProposal}, + {"2.2", strings.Join([]string{PeerFeatureProposalCompression, "test"}, ","), pfCompressedProposal}, + {"2.2", strings.Join([]string{PeerFeatureProposalCompression, "test"}, ", "), pfCompressedProposal}, + {"2.3", PeerFeatureProposalCompression, pfCompressedProposal}, + } + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + f := decodePeerFeatures(test.ver, test.hdr) + require.Equal(t, test.expected, f) + }) + } +}
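
The feature negotiation in this patch rides on the existing websocket upgrade: the dialing node lists its accepted versions in X-Algorand-Accept-Version, keeps X-Algorand-Version for backward compatibility, and advertises "ppzstd" in the new X-Algorand-Peer-Features header; the accepting node echoes the matched version and its own feature list. A minimal sketch of both sides using net/http headers (header names and version strings taken from the diff, everything else illustrative):

package main

import (
	"fmt"
	"net/http"
)

// Header names as declared in wsNetwork.go.
const (
	protocolVersionHeader       = "X-Algorand-Version"
	protocolAcceptVersionHeader = "X-Algorand-Accept-Version"
	peerFeaturesHeader          = "X-Algorand-Peer-Features"
)

func main() {
	// Dialing side (tryConnect): advertise every supported version plus the feature list.
	request := make(http.Header)
	for _, v := range []string{"2.2", "2.1"} {
		request.Add(protocolAcceptVersionHeader, v)
	}
	request.Set(protocolVersionHeader, "2.2") // kept for backward compatibility
	request.Set(peerFeaturesHeader, "ppzstd")

	// Accepting side (ServeHTTP): reply with the matched version and its own features.
	response := make(http.Header)
	response.Set(protocolVersionHeader, "2.2")
	response.Set(peerFeaturesHeader, "ppzstd")

	fmt.Println(request.Values(protocolAcceptVersionHeader)) // [2.2 2.1]
	fmt.Println(response.Get(peerFeaturesHeader))            // ppzstd
}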
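
Whether the announced features are honored depends on the negotiated version: decodePeerFeatures ignores the header entirely unless the version is at least 2.2, then scans the comma-separated list for "ppzstd". A standalone sketch of that gate, with the types and names re-declared locally; the version comparison is collapsed into a single check here, whereas the patch compares major and minor against versionPeerFeaturesNum separately:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

type peerFeatureFlag int

const pfCompressedProposal peerFeatureFlag = 1

const peerFeatureProposalCompression = "ppzstd"

// parseVersion splits "major.minor" into its numeric components.
func parseVersion(v string) (major, minor int64, err error) {
	parts := strings.Split(v, ".")
	if len(parts) != 2 {
		return 0, 0, fmt.Errorf("version %s does not have two components", v)
	}
	if major, err = strconv.ParseInt(parts[0], 10, 8); err != nil {
		return 0, 0, err
	}
	if minor, err = strconv.ParseInt(parts[1], 10, 8); err != nil {
		return 0, 0, err
	}
	return major, minor, nil
}

// decodeFeatures returns no features for peers older than 2.2, otherwise it
// folds every recognized entry of the announced feature list into a flag set.
func decodeFeatures(negotiatedVersion, header string) peerFeatureFlag {
	major, minor, err := parseVersion(negotiatedVersion)
	if err != nil || major < 2 || (major == 2 && minor < 2) {
		return 0
	}
	var features peerFeatureFlag
	for _, part := range strings.Split(header, ",") {
		if strings.TrimSpace(part) == peerFeatureProposalCompression {
			features |= pfCompressedProposal
		}
	}
	return features
}

func main() {
	fmt.Println(decodeFeatures("2.1", "ppzstd")) // 0: version predates peer features
	fmt.Println(decodeFeatures("2.2", "ppzstd")) // 1: proposal compression enabled
}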
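
Compression on the send path is opt-in twice over: preparePeerData builds the compressed batch only for high-priority broadcasts that actually contain a proposal payload and only if at least one connected peer advertised the feature; everything else reuses the plain batch. A stripped-down sketch of that gate with placeholder types (the real code inspects broadcastRequest.tags and wsPeer feature flags):

package main

import "fmt"

// tagged is a placeholder for one broadcast entry (tag + payload).
type tagged struct {
	tag  string
	data []byte
}

// shouldCompress mirrors the shape of checkCanCompress: a proposal payload ("PP")
// must be present and at least one peer must support compressed proposals.
func shouldCompress(batch []tagged, peerSupportsZstd []bool) bool {
	hasProposal := false
	for _, m := range batch {
		if m.tag == "PP" {
			hasProposal = true
			break
		}
	}
	if !hasProposal {
		return false
	}
	for _, supported := range peerSupportsZstd {
		if supported {
			return true
		}
	}
	return false
}

func main() {
	batch := []tagged{{"AV", []byte("vote")}, {"PP", []byte("proposal")}}
	fmt.Println(shouldCompress(batch, []bool{false, false})) // false: no capable peer, skip the extra encode
	fmt.Println(shouldCompress(batch, []bool{false, true}))  // true: compress once, pick the batch per peer
}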
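
On the wire a compressed proposal is just the usual tag prefix followed by a zstd frame, so a receiver can tell the two encodings apart by the 4-byte zstd magic (0x28 0xb5 0x2f 0xfd) checked in zstdProposalDecompressor.accept. A rough sketch of that framing with the DataDog/zstd bindings; the patch itself compresses into a pre-sized buffer via zstd.CompressBound/CompressLevel and caps the decompressed size, both of which this sketch skips for brevity:

package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/DataDog/zstd"
)

var zstdMagic = [4]byte{0x28, 0xb5, 0x2f, 0xfd}

// frame prepends the two-byte tag to the zstd-compressed payload,
// falling back to the raw payload if compression fails.
func frame(tag string, payload []byte) []byte {
	compressed, err := zstd.Compress(nil, payload)
	if err != nil {
		return append([]byte(tag), payload...)
	}
	return append([]byte(tag), compressed...)
}

// unframe strips the tag and decompresses only when the zstd magic is present,
// so plain payloads from older peers pass through untouched.
func unframe(msg []byte, tagLen int) ([]byte, error) {
	body := msg[tagLen:]
	if len(body) > 4 && bytes.Equal(body[:4], zstdMagic[:]) {
		r := zstd.NewReader(bytes.NewReader(body))
		defer r.Close()
		return io.ReadAll(r)
	}
	return body, nil
}

func main() {
	wire := frame("PP", bytes.Repeat([]byte("proposal"), 1024))
	payload, err := unframe(wire, 2)
	fmt.Println(len(wire), len(payload), err) // compressed wire size, 8192, <nil>
}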
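
MaxDecompressedMessageSize exists because a tiny zstd frame can inflate to an arbitrarily large buffer, so zstdProposalDecompressor.convert streams the output and aborts once the 20 MiB cap is crossed rather than trusting any size hint in the frame. A self-contained sketch of the same guard, using io.LimitReader in place of the patch's grow-and-read loop:

package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/DataDog/zstd"
)

const maxDecompressed = 20 * 1024 * 1024 // same cap as MaxDecompressedMessageSize

// decompressCapped inflates data but refuses to produce more than maxDecompressed
// bytes, which is the zip-bomb protection the cap is there for.
func decompressCapped(data []byte) ([]byte, error) {
	r := zstd.NewReader(bytes.NewReader(data))
	defer r.Close()
	// Allow one byte past the cap: seeing it proves the real output is too large.
	out, err := io.ReadAll(io.LimitReader(r, maxDecompressed+1))
	if err != nil {
		return nil, err
	}
	if len(out) > maxDecompressed {
		return nil, fmt.Errorf("decompressed data is too large: more than %d bytes", maxDecompressed)
	}
	return out, nil
}

func main() {
	small, _ := zstd.Compress(nil, bytes.Repeat([]byte{'x'}, 1024))
	bomb, _ := zstd.Compress(nil, make([]byte, maxDecompressed+1)) // ~20 MiB of zeros compresses to almost nothing

	_, err := decompressCapped(small)
	fmt.Println(err) // <nil>
	_, err = decompressCapped(bomb)
	fmt.Println(err) // decompressed data is too large: more than 20971520 bytes
}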