From 0ad98169772b0c1b197e1156e6395059f3e422e2 Mon Sep 17 00:00:00 2001 From: Wondertan Date: Mon, 31 Jul 2023 12:21:13 +0200 Subject: [PATCH] refactor: harden verification --- p2p/subscriber.go | 3 +- store/store.go | 3 +- sync/sync_head.go | 2 +- sync/sync_test.go | 3 +- sync/verify/verify.go | 112 ++++++++++++++++++++++++++++++ sync/verify/verify_test.go | 139 +++++++++++++++++++++++++++++++++++++ 6 files changed, 258 insertions(+), 4 deletions(-) create mode 100644 sync/verify/verify.go create mode 100644 sync/verify/verify_test.go diff --git a/p2p/subscriber.go b/p2p/subscriber.go index 30341333..8f89caa6 100644 --- a/p2p/subscriber.go +++ b/p2p/subscriber.go @@ -9,6 +9,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/celestiaorg/go-header" + "github.com/celestiaorg/go-header/sync/verify" ) // Subscriber manages the lifecycle and relationship of header Module @@ -77,7 +78,7 @@ func (p *Subscriber[H]) SetVerifier(val func(context.Context, H) error) error { // additional unmarhalling msg.ValidatorData = hdr - var verErr *header.VerifyError + var verErr *verify.VerifyError err = val(ctx, hdr) switch { case err == nil: diff --git a/store/store.go b/store/store.go index ae389087..afaa61ec 100644 --- a/store/store.go +++ b/store/store.go @@ -12,6 +12,7 @@ import ( logging "github.com/ipfs/go-log/v2" "github.com/celestiaorg/go-header" + "github.com/celestiaorg/go-header/sync/verify" ) var log = logging.Logger("header/store") @@ -332,7 +333,7 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { err = head.Verify(h) if err != nil { - var verErr *header.VerifyError + var verErr *verify.VerifyError if errors.As(err, &verErr) { log.Errorw("invalid header", "height_of_head", head.Height(), diff --git a/sync/sync_head.go b/sync/sync_head.go index 8d79b715..9fa3f066 100644 --- a/sync/sync_head.go +++ b/sync/sync_head.go @@ -162,7 +162,7 @@ func (s *Syncer[H]) verify(ctx context.Context, newHead H) (bool, error) { var heightThreshold int64 if s.Params.TrustingPeriod != 0 && s.Params.blockTime != 0 { buffer := time.Hour * 6 / s.Params.blockTime // small buffer to account for network delays - heightThreshold = int64(s.Params.TrustingPeriod / s.Params.blockTime + buffer) + heightThreshold = int64(s.Params.TrustingPeriod/s.Params.blockTime + buffer) } err = header.Verify(sbjHead, newHead, heightThreshold) diff --git a/sync/sync_test.go b/sync/sync_test.go index 1cc45ab3..470b8367 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -12,6 +12,7 @@ import ( "github.com/celestiaorg/go-header/headertest" "github.com/celestiaorg/go-header/local" "github.com/celestiaorg/go-header/store" + "github.com/celestiaorg/go-header/sync/verify" ) func TestSyncSimpleRequestingHead(t *testing.T) { @@ -277,7 +278,7 @@ func TestSyncerIncomingDuplicate(t *testing.T) { time.Sleep(time.Millisecond * 10) - var verErr *header.VerifyError + var verErr *verify.VerifyError err = syncer.incomingNetworkHead(ctx, range1[len(range1)-1]) assert.ErrorAs(t, err, &verErr) diff --git a/sync/verify/verify.go b/sync/verify/verify.go new file mode 100644 index 00000000..32255def --- /dev/null +++ b/sync/verify/verify.go @@ -0,0 +1,112 @@ +// TODO(@Wondertan): Should be just part of sync pkg and not subpkg +// +// Fix after adjacency requirement is removed from the Store. +package verify + +import ( + "errors" + "fmt" + "time" + + "github.com/celestiaorg/go-header" +) + +// DefaultHeightThreshold defines default height threshold beyond which headers are rejected +// NOTE: Compared against subjective head which is guaranteed to be non-expired +const DefaultHeightThreshold int64 = 40000 // ~ 7 days of 15 second headers + +// VerifyError is thrown during for Headers failed verification. +type VerifyError struct { + // Reason why verification failed as inner error. + Reason error + // SoftFailure means verification did not have enough information to definitively conclude a + // Header was correct or not. + // May happen with recent Headers during unfinished historical sync or because of local errors. + // TODO(@Wondertan): Better be part of signature Header.Verify() (bool, error), but kept here + // not to break + SoftFailure bool +} + +func (vr *VerifyError) Error() string { + return fmt.Sprintf("header: verify: %s", vr.Reason.Error()) +} + +func (vr *VerifyError) Unwrap() error { + return vr.Reason +} + +// Verify verifies untrusted Header against trusted following general Header checks and +// custom user-specific checks defined in Header.Verify +// +// If heightThreshold is zero, uses DefaultHeightThreshold. +// Always returns VerifyError. +func Verify[H header.Header](trstd, untrstd H, heightThreshold int64) error { + // general mandatory verification + err := verify[H](trstd, untrstd, heightThreshold) + if err != nil { + return &VerifyError{Reason: err} + } + // user defined verification + err = trstd.Verify(untrstd) + if err == nil { + return nil + } + // if that's an error, ensure we always return VerifyError + var verErr *VerifyError + if !errors.As(err, &verErr) { + verErr = &VerifyError{Reason: err} + } + // check adjacency of failed verification + adjacent := untrstd.Height() == trstd.Height()+1 + if !adjacent { + // if non-adjacent, we don't know if the header is *really* wrong + // so set as soft + verErr.SoftFailure = true + } + // we trust adjacent verification to it's fullest + // if verification fails - the header is *really* wrong + return verErr +} + +// verify is a little bro of Verify yet performs mandatory Header checks +// for any Header implementation. +func verify[H header.Header](trstd, untrstd H, heightThreshold int64) error { + if heightThreshold == 0 { + heightThreshold = DefaultHeightThreshold + } + + if untrstd.IsZero() { + return fmt.Errorf("zero header") + } + + if untrstd.ChainID() != trstd.ChainID() { + return fmt.Errorf("wrong header chain id %s, not %s", untrstd.ChainID(), trstd.ChainID()) + } + + if !untrstd.Time().After(trstd.Time()) { + return fmt.Errorf("unordered header timestamp %v is before %v", untrstd.Time(), trstd.Time()) + } + + now := time.Now() + if !untrstd.Time().Before(now.Add(clockDrift)) { + return fmt.Errorf("header timestamp %v is from future (now: %v, clock_drift: %v)", untrstd.Time(), now, clockDrift) + } + + known := untrstd.Height() <= trstd.Height() + if known { + return fmt.Errorf("known header height %d, current %d", untrstd.Height(), trstd.Height()) + } + // reject headers with height too far from the future + // this is essential for headers failed non-adjacent verification + // yet taken as sync target + adequateHeight := untrstd.Height()-trstd.Height() < heightThreshold + if !adequateHeight { + return fmt.Errorf("header height %d is far from future (current: %d, threshold: %d)", untrstd.Height(), trstd.Height(), heightThreshold) + } + + return nil +} + +// clockDrift defines how much new header's time can drift into +// the future relative to the now time during verification. +var clockDrift = 10 * time.Second diff --git a/sync/verify/verify_test.go b/sync/verify/verify_test.go new file mode 100644 index 00000000..954fd73c --- /dev/null +++ b/sync/verify/verify_test.go @@ -0,0 +1,139 @@ +package verify + +import ( + "errors" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/celestiaorg/go-header/headertest" +) + +func TestVerify(t *testing.T) { + suite := headertest.NewTestSuite(t) + trusted := suite.GenDummyHeaders(1)[0] + + tests := []struct { + prepare func() *headertest.DummyHeader + err bool + soft bool + }{ + { + prepare: func() *headertest.DummyHeader { + return nil + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.VerifyFailure = true + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.VerifyFailure = true + return untrusted + }, + err: true, + soft: true, // soft because non-adjacent + }, + { + prepare: func() *headertest.DummyHeader { + return suite.NextHeader() + }, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + err := Verify(trusted, test.prepare(), 0) + if test.err { + var verErr *VerifyError + assert.ErrorAs(t, err, &verErr) + assert.NotNil(t, errors.Unwrap(verErr)) + assert.Equal(t, test.soft, verErr.SoftFailure) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_verify(t *testing.T) { + suite := headertest.NewTestSuite(t) + trusted := suite.GenDummyHeaders(1)[0] + + tests := []struct { + prepare func() *headertest.DummyHeader + err bool + }{ + { + prepare: func() *headertest.DummyHeader { + return suite.NextHeader() + }, + }, + { + prepare: func() *headertest.DummyHeader { + return nil + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.ChainID = "gtmb" + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Time = untrusted.Raw.Time.Truncate(time.Minute * 10) + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Time = untrusted.Raw.Time.Add(time.Minute) + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Height = trusted.Height() + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Height += 100000 + return untrusted + }, + err: true, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + err := verify(trusted, test.prepare(), 0) + if test.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +}