Skip to content

Commit

Permalink
Fix unreadable message errors (ava-labs#1585)
Browse files Browse the repository at this point in the history
  • Loading branch information
morrisettathena authored Jun 8, 2023
1 parent cdf86ae commit 400dd66
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 6 deletions.
130 changes: 129 additions & 1 deletion message/internal_msg_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package message

import (
"fmt"
"time"

"github.com/ava-labs/avalanchego/ids"
Expand All @@ -15,56 +16,80 @@ import (

var (
disconnected = &Disconnected{}
timeout = &Timeout{}
gossipRequest = &GossipRequest{}
timeout = &Timeout{}

_ fmt.Stringer = (*GetStateSummaryFrontierFailed)(nil)
_ chainIDGetter = (*GetStateSummaryFrontierFailed)(nil)
_ requestIDGetter = (*GetStateSummaryFrontierFailed)(nil)

_ fmt.Stringer = (*GetAcceptedStateSummaryFailed)(nil)
_ chainIDGetter = (*GetAcceptedStateSummaryFailed)(nil)
_ requestIDGetter = (*GetAcceptedStateSummaryFailed)(nil)

_ fmt.Stringer = (*GetAcceptedFrontierFailed)(nil)
_ chainIDGetter = (*GetAcceptedFrontierFailed)(nil)
_ requestIDGetter = (*GetAcceptedFrontierFailed)(nil)
_ engineTypeGetter = (*GetAcceptedFrontierFailed)(nil)

_ fmt.Stringer = (*GetAcceptedFailed)(nil)
_ chainIDGetter = (*GetAcceptedFailed)(nil)
_ requestIDGetter = (*GetAcceptedFailed)(nil)
_ engineTypeGetter = (*GetAcceptedFailed)(nil)

_ fmt.Stringer = (*GetAncestorsFailed)(nil)
_ chainIDGetter = (*GetAncestorsFailed)(nil)
_ requestIDGetter = (*GetAncestorsFailed)(nil)
_ engineTypeGetter = (*GetAncestorsFailed)(nil)

_ fmt.Stringer = (*GetFailed)(nil)
_ chainIDGetter = (*GetFailed)(nil)
_ requestIDGetter = (*GetFailed)(nil)
_ engineTypeGetter = (*GetFailed)(nil)

_ fmt.Stringer = (*QueryFailed)(nil)
_ chainIDGetter = (*QueryFailed)(nil)
_ requestIDGetter = (*QueryFailed)(nil)
_ engineTypeGetter = (*QueryFailed)(nil)

_ fmt.Stringer = (*AppRequestFailed)(nil)
_ chainIDGetter = (*AppRequestFailed)(nil)
_ requestIDGetter = (*AppRequestFailed)(nil)

_ fmt.Stringer = (*CrossChainAppRequest)(nil)
_ sourceChainIDGetter = (*CrossChainAppRequest)(nil)
_ chainIDGetter = (*CrossChainAppRequest)(nil)
_ requestIDGetter = (*CrossChainAppRequest)(nil)

_ fmt.Stringer = (*CrossChainAppRequestFailed)(nil)
_ sourceChainIDGetter = (*CrossChainAppRequestFailed)(nil)
_ chainIDGetter = (*CrossChainAppRequestFailed)(nil)
_ requestIDGetter = (*CrossChainAppRequestFailed)(nil)

_ fmt.Stringer = (*CrossChainAppResponse)(nil)
_ sourceChainIDGetter = (*CrossChainAppResponse)(nil)
_ chainIDGetter = (*CrossChainAppResponse)(nil)
_ requestIDGetter = (*CrossChainAppResponse)(nil)

_ fmt.Stringer = (*Disconnected)(nil)

_ fmt.Stringer = (*GossipRequest)(nil)

_ fmt.Stringer = (*Timeout)(nil)
)

type GetStateSummaryFrontierFailed struct {
ChainID ids.ID `json:"chain_id,omitempty"`
RequestID uint32 `json:"request_id,omitempty"`
}

func (m *GetStateSummaryFrontierFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d",
m.ChainID, m.RequestID,
)
}

func (m *GetStateSummaryFrontierFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -94,6 +119,13 @@ type GetAcceptedStateSummaryFailed struct {
RequestID uint32 `json:"request_id,omitempty"`
}

func (m *GetAcceptedStateSummaryFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d",
m.ChainID, m.RequestID,
)
}

func (m *GetAcceptedStateSummaryFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -124,6 +156,13 @@ type GetAcceptedFrontierFailed struct {
EngineType p2p.EngineType `json:"engine_type,omitempty"`
}

func (m *GetAcceptedFrontierFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d EngineType: %s",
m.ChainID, m.RequestID, m.EngineType,
)
}

func (m *GetAcceptedFrontierFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -160,6 +199,13 @@ type GetAcceptedFailed struct {
EngineType p2p.EngineType `json:"engine_type,omitempty"`
}

func (m *GetAcceptedFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d EngineType: %s",
m.ChainID, m.RequestID, m.EngineType,
)
}

func (m *GetAcceptedFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -196,6 +242,13 @@ type GetAncestorsFailed struct {
EngineType p2p.EngineType `json:"engine_type,omitempty"`
}

func (m *GetAncestorsFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d EngineType: %s",
m.ChainID, m.RequestID, m.EngineType,
)
}

func (m *GetAncestorsFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -232,6 +285,13 @@ type GetFailed struct {
EngineType p2p.EngineType `json:"engine_type,omitempty"`
}

func (m *GetFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d EngineType: %s",
m.ChainID, m.RequestID, m.EngineType,
)
}

func (m *GetFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -268,6 +328,13 @@ type QueryFailed struct {
EngineType p2p.EngineType `json:"engine_type,omitempty"`
}

func (m *QueryFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d EngineType: %s",
m.ChainID, m.RequestID, m.EngineType,
)
}

func (m *QueryFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -303,6 +370,13 @@ type AppRequestFailed struct {
RequestID uint32 `json:"request_id,omitempty"`
}

func (m *AppRequestFailed) String() string {
return fmt.Sprintf(
"ChainID: %s RequestID: %d",
m.ChainID, m.RequestID,
)
}

func (m *AppRequestFailed) GetChainId() []byte {
return m.ChainID[:]
}
Expand Down Expand Up @@ -334,6 +408,13 @@ type CrossChainAppRequest struct {
Message []byte `json:"message,omitempty"`
}

func (m *CrossChainAppRequest) String() string {
return fmt.Sprintf(
"SourceChainID: %s DestinationChainID: %s RequestID: %d Message: 0x%x",
m.SourceChainID, m.DestinationChainID, m.RequestID, m.Message,
)
}

func (m *CrossChainAppRequest) GetSourceChainID() ids.ID {
return m.SourceChainID
}
Expand Down Expand Up @@ -373,6 +454,13 @@ type CrossChainAppRequestFailed struct {
RequestID uint32 `json:"request_id,omitempty"`
}

func (m *CrossChainAppRequestFailed) String() string {
return fmt.Sprintf(
"SourceChainID: %s DestinationChainID: %s RequestID: %d",
m.SourceChainID, m.DestinationChainID, m.RequestID,
)
}

func (m *CrossChainAppRequestFailed) GetSourceChainID() ids.ID {
return m.SourceChainID
}
Expand Down Expand Up @@ -410,6 +498,13 @@ type CrossChainAppResponse struct {
Message []byte `json:"message,omitempty"`
}

func (m *CrossChainAppResponse) String() string {
return fmt.Sprintf(
"SourceChainID: %s DestinationChainID: %s RequestID: %d Message: 0x%x",
m.SourceChainID, m.DestinationChainID, m.RequestID, m.Message,
)
}

func (m *CrossChainAppResponse) GetSourceChainID() ids.ID {
return m.SourceChainID
}
Expand Down Expand Up @@ -446,6 +541,13 @@ type Connected struct {
NodeVersion *version.Application `json:"node_version,omitempty"`
}

func (m *Connected) String() string {
return fmt.Sprintf(
"NodeVersion: %s",
m.NodeVersion,
)
}

func InternalConnected(nodeID ids.NodeID, nodeVersion *version.Application) InboundMessage {
return &inboundMessage{
nodeID: nodeID,
Expand All @@ -463,6 +565,13 @@ type ConnectedSubnet struct {
SubnetID ids.ID `json:"subnet_id,omitempty"`
}

func (m *ConnectedSubnet) String() string {
return fmt.Sprintf(
"SubnetID: %s",
m.SubnetID,
)
}

// InternalConnectedSubnet returns a message that indicates the node with [nodeID] is
// connected to the subnet with the given [subnetID].
func InternalConnectedSubnet(nodeID ids.NodeID, subnetID ids.ID) InboundMessage {
Expand All @@ -478,6 +587,10 @@ func InternalConnectedSubnet(nodeID ids.NodeID, subnetID ids.ID) InboundMessage

type Disconnected struct{}

func (Disconnected) String() string {
return ""
}

func InternalDisconnected(nodeID ids.NodeID) InboundMessage {
return &inboundMessage{
nodeID: nodeID,
Expand All @@ -491,6 +604,13 @@ type VMMessage struct {
Notification uint32 `json:"notification,omitempty"`
}

func (m *VMMessage) String() string {
return fmt.Sprintf(
"Notification: %d",
m.Notification,
)
}

func InternalVMMessage(
nodeID ids.NodeID,
notification uint32,
Expand All @@ -507,6 +627,10 @@ func InternalVMMessage(

type GossipRequest struct{}

func (GossipRequest) String() string {
return ""
}

func InternalGossipRequest(
nodeID ids.NodeID,
) InboundMessage {
Expand All @@ -520,6 +644,10 @@ func InternalGossipRequest(

type Timeout struct{}

func (Timeout) String() string {
return ""
}

func InternalTimeout(nodeID ids.NodeID) InboundMessage {
return &inboundMessage{
nodeID: nodeID,
Expand Down
12 changes: 9 additions & 3 deletions message/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ var (

// InboundMessage represents a set of fields for an inbound message
type InboundMessage interface {
fmt.Stringer
// NodeID returns the ID of the node that sent this message
NodeID() ids.NodeID
// Op returns the op that describes this message type
Op() Op
// Message returns the message that was sent
Message() any
Message() fmt.Stringer
// Expiration returns the time that the sender will have already timed out
// this request
Expiration() time.Time
Expand All @@ -54,7 +55,7 @@ type InboundMessage interface {
type inboundMessage struct {
nodeID ids.NodeID
op Op
message any
message fmt.Stringer
expiration time.Time
onFinishedHandling func()
bytesSavedCompression int
Expand All @@ -68,7 +69,7 @@ func (m *inboundMessage) Op() Op {
return m.op
}

func (m *inboundMessage) Message() any {
func (m *inboundMessage) Message() fmt.Stringer {
return m.message
}

Expand All @@ -86,6 +87,11 @@ func (m *inboundMessage) BytesSavedCompression() int {
return m.bytesSavedCompression
}

func (m *inboundMessage) String() string {
return fmt.Sprintf("%s Op: %s Message: %s",
m.nodeID, m.op, m.message)
}

// OutboundMessage represents a set of fields for an outbound message that can
// be serialized into a byte stream
type OutboundMessage interface {
Expand Down
34 changes: 34 additions & 0 deletions message/messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,40 @@ func TestMessage(t *testing.T) {
}
}

// Tests the Stringer interface on inbound messages
func TestInboundMessageToString(t *testing.T) {
t.Parallel()

require := require.New(t)

mb, err := newMsgBuilder(
logging.NoLog{},
"test",
prometheus.NewRegistry(),
5*time.Second,
)
require.NoError(err)

// msg that will become the tested InboundMessage
msg := &p2p.Message{
Message: &p2p.Message_Pong{
Pong: &p2p.Pong{
Uptime: 100,
},
},
}
msgBytes, err := proto.Marshal(msg)
require.NoError(err)

inboundMsg, err := mb.parseInbound(msgBytes, ids.EmptyNodeID, func() {})
require.NoError(err)

require.Equal("NodeID-111111111111111111116DBWJs Op: pong Message: uptime:100", inboundMsg.String())

internalMsg := InternalGetStateSummaryFrontierFailed(ids.EmptyNodeID, ids.Empty, 1)
require.Equal("NodeID-111111111111111111116DBWJs Op: get_state_summary_frontier_failed Message: ChainID: 11111111111111111111111111111111LpoYY RequestID: 1", internalMsg.String())
}

func TestEmptyInboundMessage(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion message/ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func (op Op) String() string {
}
}

func Unwrap(m *p2p.Message) (interface{}, error) {
func Unwrap(m *p2p.Message) (fmt.Stringer, error) {
switch msg := m.GetMessage().(type) {
// Handshake:
case *p2p.Message_Ping:
Expand Down
Loading

0 comments on commit 400dd66

Please sign in to comment.