diff --git a/protocol/skeleton.go b/protocol/skeleton.go index 1a20303d9d..3d1b5e2b7e 100644 --- a/protocol/skeleton.go +++ b/protocol/skeleton.go @@ -2,12 +2,17 @@ package protocol import ( "context" + "errors" "fmt" "github.com/dogechain-lab/jury/protocol/proto" "github.com/dogechain-lab/jury/types" ) +var ( + ErrNilHeaderResponse = errors.New("header response is nil") +) + func getHeaders(clt proto.V1Client, req *proto.GetHeadersRequest) ([]*types.Header, error) { resp, err := clt.GetHeaders(context.Background(), req) if err != nil { @@ -17,6 +22,11 @@ func getHeaders(clt proto.V1Client, req *proto.GetHeadersRequest) ([]*types.Head headers := []*types.Header{} for _, obj := range resp.Objs { + if obj == nil || obj.Spec == nil { + // this nil header comes from a faulty node, reject all blocks of it. + return nil, ErrNilHeaderResponse + } + header := &types.Header{} if err := header.UnmarshalRLP(obj.Spec.Value); err != nil { return nil, err diff --git a/protocol/syncer.go b/protocol/syncer.go index 09ec93be3e..babf3bcb21 100644 --- a/protocol/syncer.go +++ b/protocol/syncer.go @@ -37,6 +37,9 @@ var ( ErrForkNotFound = errors.New("fork not found") ErrPopTimeout = errors.New("timeout") ErrConnectionClosed = errors.New("connection closed") + ErrTooManyHeaders = errors.New("unexpected more than 1 result") + ErrDecodeDifficulty = errors.New("failed to decode difficulty") + ErrInvalidTypeAssertion = errors.New("invalid type assertion") ) // SyncPeer is a representation of the peer the node is syncing with @@ -187,7 +190,7 @@ func statusFromProto(p *proto.V1Status) (*Status, error) { diff, ok := new(big.Int).SetString(p.Difficulty, 10) if !ok { - return nil, fmt.Errorf("failed to decode difficulty") + return nil, ErrDecodeDifficulty } s.Difficulty = diff @@ -494,7 +497,7 @@ func (s *Syncer) DeletePeer(peerID peer.ID) error { if ok { syncPeer, ok := p.(*SyncPeer) if !ok { - return errors.New("invalid type assertion") + return ErrInvalidTypeAssertion } if err := syncPeer.conn.Close(); err != nil { @@ -719,12 +722,18 @@ func getHeader(clt proto.V1Client, num *uint64, hash *types.Hash) (*types.Header } if len(resp.Objs) != 1 { - return nil, fmt.Errorf("unexpected more than 1 result") + return nil, ErrTooManyHeaders + } + + obj := resp.Objs[0] + + if obj == nil || obj.Spec == nil || len(obj.Spec.Value) == 0 { + return nil, ErrNilHeaderResponse } header := &types.Header{} - if err := header.UnmarshalRLP(resp.Objs[0].Spec.Value); err != nil { + if err := header.UnmarshalRLP(obj.Spec.Value); err != nil { return nil, err }