Skip to content

Commit

Permalink
refactor: harden verification
Browse files Browse the repository at this point in the history
  • Loading branch information
Wondertan committed Aug 24, 2023
1 parent f79c35c commit 0ad9816
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 4 deletions.
3 changes: 2 additions & 1 deletion p2p/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer"

"github.com/celestiaorg/go-header"
"github.com/celestiaorg/go-header/sync/verify"
)

// Subscriber manages the lifecycle and relationship of header Module
Expand Down Expand Up @@ -77,7 +78,7 @@ func (p *Subscriber[H]) SetVerifier(val func(context.Context, H) error) error {
// additional unmarhalling
msg.ValidatorData = hdr

var verErr *header.VerifyError
var verErr *verify.VerifyError
err = val(ctx, hdr)
switch {
case err == nil:
Expand Down
3 changes: 2 additions & 1 deletion store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
logging "github.com/ipfs/go-log/v2"

"github.com/celestiaorg/go-header"
"github.com/celestiaorg/go-header/sync/verify"
)

var log = logging.Logger("header/store")
Expand Down Expand Up @@ -332,7 +333,7 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error {

err = head.Verify(h)
if err != nil {
var verErr *header.VerifyError
var verErr *verify.VerifyError
if errors.As(err, &verErr) {
log.Errorw("invalid header",
"height_of_head", head.Height(),
Expand Down
2 changes: 1 addition & 1 deletion sync/sync_head.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (s *Syncer[H]) verify(ctx context.Context, newHead H) (bool, error) {
var heightThreshold int64
if s.Params.TrustingPeriod != 0 && s.Params.blockTime != 0 {
buffer := time.Hour * 6 / s.Params.blockTime // small buffer to account for network delays
heightThreshold = int64(s.Params.TrustingPeriod / s.Params.blockTime + buffer)
heightThreshold = int64(s.Params.TrustingPeriod/s.Params.blockTime + buffer)
}

err = header.Verify(sbjHead, newHead, heightThreshold)
Expand Down
3 changes: 2 additions & 1 deletion sync/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/celestiaorg/go-header/headertest"
"github.com/celestiaorg/go-header/local"
"github.com/celestiaorg/go-header/store"
"github.com/celestiaorg/go-header/sync/verify"
)

func TestSyncSimpleRequestingHead(t *testing.T) {
Expand Down Expand Up @@ -277,7 +278,7 @@ func TestSyncerIncomingDuplicate(t *testing.T) {

time.Sleep(time.Millisecond * 10)

var verErr *header.VerifyError
var verErr *verify.VerifyError
err = syncer.incomingNetworkHead(ctx, range1[len(range1)-1])
assert.ErrorAs(t, err, &verErr)

Expand Down
112 changes: 112 additions & 0 deletions sync/verify/verify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// TODO(@Wondertan): Should be just part of sync pkg and not subpkg
//
// Fix after adjacency requirement is removed from the Store.
package verify

import (
"errors"
"fmt"
"time"

"github.com/celestiaorg/go-header"
)

// DefaultHeightThreshold defines default height threshold beyond which headers are rejected
// NOTE: Compared against subjective head which is guaranteed to be non-expired
const DefaultHeightThreshold int64 = 40000 // ~ 7 days of 15 second headers

// VerifyError is thrown during for Headers failed verification.
type VerifyError struct {
// Reason why verification failed as inner error.
Reason error
// SoftFailure means verification did not have enough information to definitively conclude a
// Header was correct or not.
// May happen with recent Headers during unfinished historical sync or because of local errors.
// TODO(@Wondertan): Better be part of signature Header.Verify() (bool, error), but kept here
// not to break
SoftFailure bool
}

func (vr *VerifyError) Error() string {
return fmt.Sprintf("header: verify: %s", vr.Reason.Error())
}

func (vr *VerifyError) Unwrap() error {
return vr.Reason
}

// Verify verifies untrusted Header against trusted following general Header checks and
// custom user-specific checks defined in Header.Verify
//
// If heightThreshold is zero, uses DefaultHeightThreshold.
// Always returns VerifyError.
func Verify[H header.Header](trstd, untrstd H, heightThreshold int64) error {
// general mandatory verification
err := verify[H](trstd, untrstd, heightThreshold)
if err != nil {
return &VerifyError{Reason: err}
}
// user defined verification
err = trstd.Verify(untrstd)
if err == nil {
return nil
}
// if that's an error, ensure we always return VerifyError
var verErr *VerifyError
if !errors.As(err, &verErr) {
verErr = &VerifyError{Reason: err}
}
// check adjacency of failed verification
adjacent := untrstd.Height() == trstd.Height()+1
if !adjacent {
// if non-adjacent, we don't know if the header is *really* wrong
// so set as soft
verErr.SoftFailure = true
}
// we trust adjacent verification to it's fullest
// if verification fails - the header is *really* wrong
return verErr
}

// verify is a little bro of Verify yet performs mandatory Header checks
// for any Header implementation.
func verify[H header.Header](trstd, untrstd H, heightThreshold int64) error {
if heightThreshold == 0 {
heightThreshold = DefaultHeightThreshold
}

if untrstd.IsZero() {
return fmt.Errorf("zero header")
}

if untrstd.ChainID() != trstd.ChainID() {
return fmt.Errorf("wrong header chain id %s, not %s", untrstd.ChainID(), trstd.ChainID())
}

if !untrstd.Time().After(trstd.Time()) {
return fmt.Errorf("unordered header timestamp %v is before %v", untrstd.Time(), trstd.Time())
}

now := time.Now()
if !untrstd.Time().Before(now.Add(clockDrift)) {
return fmt.Errorf("header timestamp %v is from future (now: %v, clock_drift: %v)", untrstd.Time(), now, clockDrift)
}

known := untrstd.Height() <= trstd.Height()
if known {
return fmt.Errorf("known header height %d, current %d", untrstd.Height(), trstd.Height())
}
// reject headers with height too far from the future
// this is essential for headers failed non-adjacent verification
// yet taken as sync target
adequateHeight := untrstd.Height()-trstd.Height() < heightThreshold
if !adequateHeight {
return fmt.Errorf("header height %d is far from future (current: %d, threshold: %d)", untrstd.Height(), trstd.Height(), heightThreshold)
}

return nil
}

// clockDrift defines how much new header's time can drift into
// the future relative to the now time during verification.
var clockDrift = 10 * time.Second
139 changes: 139 additions & 0 deletions sync/verify/verify_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package verify

import (
"errors"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/celestiaorg/go-header/headertest"
)

func TestVerify(t *testing.T) {
suite := headertest.NewTestSuite(t)
trusted := suite.GenDummyHeaders(1)[0]

tests := []struct {
prepare func() *headertest.DummyHeader
err bool
soft bool
}{
{
prepare: func() *headertest.DummyHeader {
return nil
},
err: true,
},
{
prepare: func() *headertest.DummyHeader {
untrusted := suite.NextHeader()
untrusted.VerifyFailure = true
return untrusted
},
err: true,
},
{
prepare: func() *headertest.DummyHeader {
untrusted := suite.NextHeader()
untrusted.VerifyFailure = true
return untrusted
},
err: true,
soft: true, // soft because non-adjacent
},
{
prepare: func() *headertest.DummyHeader {
return suite.NextHeader()
},
},
}

for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
err := Verify(trusted, test.prepare(), 0)
if test.err {
var verErr *VerifyError
assert.ErrorAs(t, err, &verErr)
assert.NotNil(t, errors.Unwrap(verErr))
assert.Equal(t, test.soft, verErr.SoftFailure)
} else {
assert.NoError(t, err)
}
})
}
}

func Test_verify(t *testing.T) {
suite := headertest.NewTestSuite(t)
trusted := suite.GenDummyHeaders(1)[0]

tests := []struct {
prepare func() *headertest.DummyHeader
err bool
}{
{
prepare: func() *headertest.DummyHeader {
return suite.NextHeader()
},
},
{
prepare: func() *headertest.DummyHeader {
return nil
},
err: true,
},
{
prepare: func() *headertest.DummyHeader {
untrusted := suite.NextHeader()
untrusted.Raw.ChainID = "gtmb"
return untrusted
},
err: true,
},
{
prepare: func() *headertest.DummyHeader {
untrusted := suite.NextHeader()
untrusted.Raw.Time = untrusted.Raw.Time.Truncate(time.Minute * 10)
return untrusted
},
err: true,
},
{
prepare: func() *headertest.DummyHeader {
untrusted := suite.NextHeader()
untrusted.Raw.Time = untrusted.Raw.Time.Add(time.Minute)
return untrusted
},
err: true,
},
{
prepare: func() *headertest.DummyHeader {
untrusted := suite.NextHeader()
untrusted.Raw.Height = trusted.Height()
return untrusted
},
err: true,
},
{
prepare: func() *headertest.DummyHeader {
untrusted := suite.NextHeader()
untrusted.Raw.Height += 100000
return untrusted
},
err: true,
},
}

for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
err := verify(trusted, test.prepare(), 0)
if test.err {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

0 comments on commit 0ad9816

Please sign in to comment.