Skip to content

Commit

Permalink
Functions to subscribe and wait for network messages of specifed type…
Browse files Browse the repository at this point in the history
…s added to itest network client.

SimpleSnapshot test reimplemented using assertions for expected messages instead of sleeping for some time.
  • Loading branch information
alexeykiselev committed Dec 26, 2024
1 parent 227698f commit 7a27d43
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 7 deletions.
108 changes: 106 additions & 2 deletions itests/clients/net_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"math/big"
"net"
"reflect"
"sync"
"sync/atomic"
"testing"
Expand All @@ -33,6 +37,7 @@ type NetClient struct {
impl Implementation
n *networking.Network
c *networking.Config
h *handler
s *networking.Session

closing atomic.Bool
Expand Down Expand Up @@ -68,7 +73,7 @@ func NewNetClient(
s, err := n.NewSession(ctx, conn, conf)
require.NoError(t, err, "failed to establish new session to %s node", impl.String())

cli := &NetClient{ctx: ctx, t: t, impl: impl, n: n, c: conf, s: s}
cli := &NetClient{ctx: ctx, t: t, impl: impl, n: n, c: conf, h: h, s: s}
h.client = cli // Set client reference in handler.
return cli
}
Expand Down Expand Up @@ -106,6 +111,78 @@ func (c *NetClient) Close() {
})
}

// SubscribeForMessages adds specified types to the message waiting queue.
// Once the awaited message received the corresponding type is removed from the queue.
func (c *NetClient) SubscribeForMessages(messageType ...reflect.Type) error {
for _, mt := range messageType {
if err := c.h.waitFor(mt); err != nil {
return err
}
}
return nil
}

// AwaitMessage waits for a message from the node for the specified timeout.
func (c *NetClient) AwaitMessage(messageType reflect.Type, timeout time.Duration) (proto.Message, error) {
select {
case <-c.ctx.Done():
return nil, c.ctx.Err()
case <-time.After(timeout):
return nil, errors.New("timeout waiting for message")
case msg := <-c.h.receiveChan():
if reflect.TypeOf(msg) != messageType {
return nil, fmt.Errorf("unexpected message type %q", reflect.TypeOf(msg).String())
}
return msg, nil
}
}

// AwaitGetBlockMessage waits for a GetBlockMessage from the node for the specified timeout and
// returns the requested block ID.
func (c *NetClient) AwaitGetBlockMessage(timeout time.Duration) (proto.BlockID, error) {
msg, err := c.AwaitMessage(reflect.TypeOf(&proto.GetBlockMessage{}), timeout)
if err != nil {
return proto.BlockID{}, err
}
getBlockMessage, ok := msg.(*proto.GetBlockMessage)
if !ok {
return proto.BlockID{}, errors.New("unexpected message type")
}
return getBlockMessage.BlockID, nil
}

// AwaitScoreMessage waits for a ScoreMessage from the node for the specified timeout and returns the received score.
func (c *NetClient) AwaitScoreMessage(timeout time.Duration) (*big.Int, error) {
msg, err := c.AwaitMessage(reflect.TypeOf(&proto.ScoreMessage{}), timeout)
if err != nil {
return nil, err
}
scoreMessage, ok := msg.(*proto.ScoreMessage)
if !ok {
return nil, errors.New("unexpected message type")
}
score := new(big.Int).SetBytes(scoreMessage.Score)
return score, nil
}

// AwaitMicroblockRequest waits for a MicroBlockRequestMessage from the node for the specified timeout and
// returns the received block ID.
func (c *NetClient) AwaitMicroblockRequest(timeout time.Duration) (proto.BlockID, error) {
msg, err := c.AwaitMessage(reflect.TypeOf(&proto.MicroBlockRequestMessage{}), timeout)
if err != nil {
return proto.BlockID{}, err
}
mbr, ok := msg.(*proto.MicroBlockRequestMessage)
if !ok {
return proto.BlockID{}, errors.New("unexpected message type")
}
r, err := proto.NewBlockIDFromBytes(mbr.TotalBlockSig)
if err != nil {
return proto.BlockID{}, err
}
return r, nil
}

func (c *NetClient) reconnect() {
c.t.Logf("Reconnecting to %q", c.s.RemoteAddr().String())
conn, err := net.Dial("tcp", c.s.RemoteAddr().String())
Expand Down Expand Up @@ -184,10 +261,13 @@ type handler struct {
peers []proto.PeerInfo
t testing.TB
client *NetClient
queue []reflect.Type
ch chan proto.Message
}

func newHandler(t testing.TB, peers []proto.PeerInfo) *handler {
return &handler{t: t, peers: peers}
ch := make(chan proto.Message, 1)
return &handler{t: t, peers: peers, ch: ch}
}

func (h *handler) OnReceive(s *networking.Session, data []byte) {
Expand All @@ -206,6 +286,15 @@ func (h *handler) OnReceive(s *networking.Session, data []byte) {
return
}
default:
if len(h.queue) == 0 { // No messages to wait for.
return
}
et := h.queue[0]
if reflect.TypeOf(msg) == et {
h.t.Logf("Received expected message of type %q", reflect.TypeOf(msg).String())
h.queue = h.queue[1:] // Pop the expected type.
h.ch <- msg
}
}
}

Expand All @@ -219,3 +308,18 @@ func (h *handler) OnClose(s *networking.Session) {
h.client.reconnect()
}
}

func (h *handler) waitFor(messageType reflect.Type) error {
if messageType == nil {
return errors.New("nil message type")
}
if messageType == reflect.TypeOf(proto.GetPeersMessage{}) {
return errors.New("cannot wait for GetPeersMessage")
}
h.queue = append(h.queue, messageType)
return nil
}

func (h *handler) receiveChan() <-chan proto.Message {
return h.ch
}
38 changes: 33 additions & 5 deletions itests/snapshot_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/binary"
"math"
"math/big"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -33,6 +34,8 @@ func (s *SimpleSnapshotSuite) SetupSuite() {
}

func (s *SimpleSnapshotSuite) TestSimpleSnapshot() {
const messageTimeout = 5 * time.Second

acc := s.Cfg.GetRichestAccount()

// Initialize genesis block ID.
Expand All @@ -58,25 +61,46 @@ func (s *SimpleSnapshotSuite) TestSimpleSnapshot() {
time.Sleep(delay)
}

err = s.Client.Connection.SubscribeForMessages(
reflect.TypeOf(&proto.GetBlockIdsMessage{}),
reflect.TypeOf(&proto.GetBlockMessage{}),
reflect.TypeOf(&proto.ScoreMessage{}),
reflect.TypeOf(&proto.MicroBlockRequestMessage{}),
)
require.NoError(s.T(), err, "failed to subscribe for messages")

// Calculate new score and send score to the node.
genesisScore := calculateScore(s.Cfg.BlockchainSettings.Genesis.BaseTarget)
blockScore := calculateCumulativeScore(genesisScore, bl.BaseTarget)
scoreMsg := &proto.ScoreMessage{Score: blockScore.Bytes()}
s.Client.Connection.SendMessage(scoreMsg)
time.Sleep(100 * time.Millisecond)

// Wait for the node to request block IDs.
_, err = s.Client.Connection.AwaitMessage(reflect.TypeOf(&proto.GetBlockIdsMessage{}), messageTimeout)
require.NoError(s.T(), err, "failed to wait for block IDs request")

// Send block IDs to the node.
blocksMsg := &proto.BlockIdsMessage{Blocks: []proto.BlockID{bl.BlockID()}}
s.Client.Connection.SendMessage(blocksMsg)
time.Sleep(100 * time.Millisecond)

// Wait for the node to request the block.
blockID, err := s.Client.Connection.AwaitGetBlockMessage(messageTimeout)
require.NoError(s.T(), err, "failed to wait for block request")
assert.Equal(s.T(), bl.BlockID(), blockID)

// Marshal the block and send it to the node.
bb, err := bl.MarshalToProtobuf(s.Cfg.BlockchainSettings.AddressSchemeCharacter)
require.NoError(s.T(), err, "failed to marshal block")
blMsg := &proto.PBBlockMessage{PBBlockBytes: bb}
s.Client.Connection.SendMessage(blMsg)

// Wait for 2.5 seconds and send mb-block.
// Wait for updated score message.
score, err := s.Client.Connection.AwaitScoreMessage(messageTimeout)
require.NoError(s.T(), err, "failed to wait for score")
assert.Equal(s.T(), blockScore, score)

// Wait for 2.5 seconds and send micro-block.
time.Sleep(2500 * time.Millisecond)

// Add transactions to block.
Expand All @@ -89,14 +113,18 @@ func (s *SimpleSnapshotSuite) TestSimpleSnapshot() {
// Create micro-block with the transaction and unchanged state hash.
mb, inv := createMicroBlockAndInv(s.T(), *bl, s.Cfg.BlockchainSettings, tx, acc.SecretKey, acc.PublicKey, sh)

// Send micro-block inv to the node.
ib, err := inv.MarshalBinary()
require.NoError(s.T(), err, "failed to marshal inv")

invMsg := &proto.MicroBlockInvMessage{Body: ib}
time.Sleep(100 * time.Millisecond)
s.Client.Connection.SendMessage(invMsg)

time.Sleep(100 * time.Millisecond)
// Wait for the node to request micro-block.
mbID, err := s.Client.Connection.AwaitMicroblockRequest(messageTimeout)
require.NoError(s.T(), err, "failed to wait for micro-block request")
assert.Equal(s.T(), inv.TotalBlockID, mbID)

// Marshal the micro-block and send it to the node.
mbb, err := mb.MarshalToProtobuf(s.Cfg.BlockchainSettings.AddressSchemeCharacter)
require.NoError(s.T(), err, "failed to marshal micro block")
mbMsg := &proto.PBMicroBlockMessage{MicroBlockBytes: mbb}
Expand Down

0 comments on commit 7a27d43

Please sign in to comment.