diff --git a/itests/clients/net_client.go b/itests/clients/net_client.go index a81a549ec..e48967bbe 100644 --- a/itests/clients/net_client.go +++ b/itests/clients/net_client.go @@ -4,9 +4,13 @@ import ( "bytes" "context" "encoding/base64" + "errors" + "fmt" "io" "log/slog" + "math/big" "net" + "reflect" "sync" "sync/atomic" "testing" @@ -33,6 +37,7 @@ type NetClient struct { impl Implementation n *networking.Network c *networking.Config + h *handler s *networking.Session closing atomic.Bool @@ -68,7 +73,7 @@ func NewNetClient( s, err := n.NewSession(ctx, conn, conf) require.NoError(t, err, "failed to establish new session to %s node", impl.String()) - cli := &NetClient{ctx: ctx, t: t, impl: impl, n: n, c: conf, s: s} + cli := &NetClient{ctx: ctx, t: t, impl: impl, n: n, c: conf, h: h, s: s} h.client = cli // Set client reference in handler. return cli } @@ -106,6 +111,78 @@ func (c *NetClient) Close() { }) } +// SubscribeForMessages adds specified types to the message waiting queue. +// Once the awaited message received the corresponding type is removed from the queue. +func (c *NetClient) SubscribeForMessages(messageType ...reflect.Type) error { + for _, mt := range messageType { + if err := c.h.waitFor(mt); err != nil { + return err + } + } + return nil +} + +// AwaitMessage waits for a message from the node for the specified timeout. +func (c *NetClient) AwaitMessage(messageType reflect.Type, timeout time.Duration) (proto.Message, error) { + select { + case <-c.ctx.Done(): + return nil, c.ctx.Err() + case <-time.After(timeout): + return nil, errors.New("timeout waiting for message") + case msg := <-c.h.receiveChan(): + if reflect.TypeOf(msg) != messageType { + return nil, fmt.Errorf("unexpected message type %q", reflect.TypeOf(msg).String()) + } + return msg, nil + } +} + +// AwaitGetBlockMessage waits for a GetBlockMessage from the node for the specified timeout and +// returns the requested block ID. +func (c *NetClient) AwaitGetBlockMessage(timeout time.Duration) (proto.BlockID, error) { + msg, err := c.AwaitMessage(reflect.TypeOf(&proto.GetBlockMessage{}), timeout) + if err != nil { + return proto.BlockID{}, err + } + getBlockMessage, ok := msg.(*proto.GetBlockMessage) + if !ok { + return proto.BlockID{}, errors.New("unexpected message type") + } + return getBlockMessage.BlockID, nil +} + +// AwaitScoreMessage waits for a ScoreMessage from the node for the specified timeout and returns the received score. +func (c *NetClient) AwaitScoreMessage(timeout time.Duration) (*big.Int, error) { + msg, err := c.AwaitMessage(reflect.TypeOf(&proto.ScoreMessage{}), timeout) + if err != nil { + return nil, err + } + scoreMessage, ok := msg.(*proto.ScoreMessage) + if !ok { + return nil, errors.New("unexpected message type") + } + score := new(big.Int).SetBytes(scoreMessage.Score) + return score, nil +} + +// AwaitMicroblockRequest waits for a MicroBlockRequestMessage from the node for the specified timeout and +// returns the received block ID. +func (c *NetClient) AwaitMicroblockRequest(timeout time.Duration) (proto.BlockID, error) { + msg, err := c.AwaitMessage(reflect.TypeOf(&proto.MicroBlockRequestMessage{}), timeout) + if err != nil { + return proto.BlockID{}, err + } + mbr, ok := msg.(*proto.MicroBlockRequestMessage) + if !ok { + return proto.BlockID{}, errors.New("unexpected message type") + } + r, err := proto.NewBlockIDFromBytes(mbr.TotalBlockSig) + if err != nil { + return proto.BlockID{}, err + } + return r, nil +} + func (c *NetClient) reconnect() { c.t.Logf("Reconnecting to %q", c.s.RemoteAddr().String()) conn, err := net.Dial("tcp", c.s.RemoteAddr().String()) @@ -184,10 +261,13 @@ type handler struct { peers []proto.PeerInfo t testing.TB client *NetClient + queue []reflect.Type + ch chan proto.Message } func newHandler(t testing.TB, peers []proto.PeerInfo) *handler { - return &handler{t: t, peers: peers} + ch := make(chan proto.Message, 1) + return &handler{t: t, peers: peers, ch: ch} } func (h *handler) OnReceive(s *networking.Session, data []byte) { @@ -206,6 +286,15 @@ func (h *handler) OnReceive(s *networking.Session, data []byte) { return } default: + if len(h.queue) == 0 { // No messages to wait for. + return + } + et := h.queue[0] + if reflect.TypeOf(msg) == et { + h.t.Logf("Received expected message of type %q", reflect.TypeOf(msg).String()) + h.queue = h.queue[1:] // Pop the expected type. + h.ch <- msg + } } } @@ -219,3 +308,18 @@ func (h *handler) OnClose(s *networking.Session) { h.client.reconnect() } } + +func (h *handler) waitFor(messageType reflect.Type) error { + if messageType == nil { + return errors.New("nil message type") + } + if messageType == reflect.TypeOf(proto.GetPeersMessage{}) { + return errors.New("cannot wait for GetPeersMessage") + } + h.queue = append(h.queue, messageType) + return nil +} + +func (h *handler) receiveChan() <-chan proto.Message { + return h.ch +} diff --git a/itests/snapshot_internal_test.go b/itests/snapshot_internal_test.go index 454ab468e..f78e216e4 100644 --- a/itests/snapshot_internal_test.go +++ b/itests/snapshot_internal_test.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "math" "math/big" + "reflect" "testing" "time" @@ -33,6 +34,8 @@ func (s *SimpleSnapshotSuite) SetupSuite() { } func (s *SimpleSnapshotSuite) TestSimpleSnapshot() { + const messageTimeout = 5 * time.Second + acc := s.Cfg.GetRichestAccount() // Initialize genesis block ID. @@ -58,25 +61,46 @@ func (s *SimpleSnapshotSuite) TestSimpleSnapshot() { time.Sleep(delay) } + err = s.Client.Connection.SubscribeForMessages( + reflect.TypeOf(&proto.GetBlockIdsMessage{}), + reflect.TypeOf(&proto.GetBlockMessage{}), + reflect.TypeOf(&proto.ScoreMessage{}), + reflect.TypeOf(&proto.MicroBlockRequestMessage{}), + ) + require.NoError(s.T(), err, "failed to subscribe for messages") + // Calculate new score and send score to the node. genesisScore := calculateScore(s.Cfg.BlockchainSettings.Genesis.BaseTarget) blockScore := calculateCumulativeScore(genesisScore, bl.BaseTarget) scoreMsg := &proto.ScoreMessage{Score: blockScore.Bytes()} s.Client.Connection.SendMessage(scoreMsg) - time.Sleep(100 * time.Millisecond) + + // Wait for the node to request block IDs. + _, err = s.Client.Connection.AwaitMessage(reflect.TypeOf(&proto.GetBlockIdsMessage{}), messageTimeout) + require.NoError(s.T(), err, "failed to wait for block IDs request") // Send block IDs to the node. blocksMsg := &proto.BlockIdsMessage{Blocks: []proto.BlockID{bl.BlockID()}} s.Client.Connection.SendMessage(blocksMsg) time.Sleep(100 * time.Millisecond) + // Wait for the node to request the block. + blockID, err := s.Client.Connection.AwaitGetBlockMessage(messageTimeout) + require.NoError(s.T(), err, "failed to wait for block request") + assert.Equal(s.T(), bl.BlockID(), blockID) + // Marshal the block and send it to the node. bb, err := bl.MarshalToProtobuf(s.Cfg.BlockchainSettings.AddressSchemeCharacter) require.NoError(s.T(), err, "failed to marshal block") blMsg := &proto.PBBlockMessage{PBBlockBytes: bb} s.Client.Connection.SendMessage(blMsg) - // Wait for 2.5 seconds and send mb-block. + // Wait for updated score message. + score, err := s.Client.Connection.AwaitScoreMessage(messageTimeout) + require.NoError(s.T(), err, "failed to wait for score") + assert.Equal(s.T(), blockScore, score) + + // Wait for 2.5 seconds and send micro-block. time.Sleep(2500 * time.Millisecond) // Add transactions to block. @@ -89,14 +113,18 @@ func (s *SimpleSnapshotSuite) TestSimpleSnapshot() { // Create micro-block with the transaction and unchanged state hash. mb, inv := createMicroBlockAndInv(s.T(), *bl, s.Cfg.BlockchainSettings, tx, acc.SecretKey, acc.PublicKey, sh) + // Send micro-block inv to the node. ib, err := inv.MarshalBinary() require.NoError(s.T(), err, "failed to marshal inv") - invMsg := &proto.MicroBlockInvMessage{Body: ib} - time.Sleep(100 * time.Millisecond) s.Client.Connection.SendMessage(invMsg) - time.Sleep(100 * time.Millisecond) + // Wait for the node to request micro-block. + mbID, err := s.Client.Connection.AwaitMicroblockRequest(messageTimeout) + require.NoError(s.T(), err, "failed to wait for micro-block request") + assert.Equal(s.T(), inv.TotalBlockID, mbID) + + // Marshal the micro-block and send it to the node. mbb, err := mb.MarshalToProtobuf(s.Cfg.BlockchainSettings.AddressSchemeCharacter) require.NoError(s.T(), err, "failed to marshal micro block") mbMsg := &proto.PBMicroBlockMessage{MicroBlockBytes: mbb}