diff --git a/plugin/evm/validators/codec.go b/plugin/evm/validators/codec.go new file mode 100644 index 0000000000..dadba8b273 --- /dev/null +++ b/plugin/evm/validators/codec.go @@ -0,0 +1,34 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package validators + +import ( + "math" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/codec/linearcodec" + "github.com/ava-labs/avalanchego/utils/wrappers" +) + +const ( + codecVersion = uint16(0) +) + +var vdrCodec codec.Manager + +func init() { + vdrCodec = codec.NewManager(math.MaxInt32) + c := linearcodec.NewDefault() + + errs := wrappers.Errs{} + errs.Add( + c.RegisterType(validatorData{}), + + vdrCodec.RegisterCodec(codecVersion, c), + ) + + if errs.Errored() { + panic(errs.Err) + } +} diff --git a/plugin/evm/validators/mock_listener.go b/plugin/evm/validators/mock_listener.go new file mode 100644 index 0000000000..d67703007d --- /dev/null +++ b/plugin/evm/validators/mock_listener.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ava-labs/subnet-evm/plugin/evm/validators (interfaces: StateCallbackListener) +// +// Generated by this command: +// +// mockgen -package=validators -destination=plugin/evm/validators/mock_listener.go github.com/ava-labs/subnet-evm/plugin/evm/validators StateCallbackListener +// + +// Package validators is a generated GoMock package. +package validators + +import ( + reflect "reflect" + + ids "github.com/ava-labs/avalanchego/ids" + gomock "go.uber.org/mock/gomock" +) + +// MockStateCallbackListener is a mock of StateCallbackListener interface. +type MockStateCallbackListener struct { + ctrl *gomock.Controller + recorder *MockStateCallbackListenerMockRecorder +} + +// MockStateCallbackListenerMockRecorder is the mock recorder for MockStateCallbackListener. +type MockStateCallbackListenerMockRecorder struct { + mock *MockStateCallbackListener +} + +// NewMockStateCallbackListener creates a new mock instance. +func NewMockStateCallbackListener(ctrl *gomock.Controller) *MockStateCallbackListener { + mock := &MockStateCallbackListener{ctrl: ctrl} + mock.recorder = &MockStateCallbackListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStateCallbackListener) EXPECT() *MockStateCallbackListenerMockRecorder { + return m.recorder +} + +// OnValidatorAdded mocks base method. +func (m *MockStateCallbackListener) OnValidatorAdded(arg0 ids.ID, arg1 ids.NodeID, arg2 uint64, arg3 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnValidatorAdded", arg0, arg1, arg2, arg3) +} + +// OnValidatorAdded indicates an expected call of OnValidatorAdded. +func (mr *MockStateCallbackListenerMockRecorder) OnValidatorAdded(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnValidatorAdded", reflect.TypeOf((*MockStateCallbackListener)(nil).OnValidatorAdded), arg0, arg1, arg2, arg3) +} + +// OnValidatorRemoved mocks base method. +func (m *MockStateCallbackListener) OnValidatorRemoved(arg0 ids.ID, arg1 ids.NodeID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnValidatorRemoved", arg0, arg1) +} + +// OnValidatorRemoved indicates an expected call of OnValidatorRemoved. +func (mr *MockStateCallbackListenerMockRecorder) OnValidatorRemoved(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnValidatorRemoved", reflect.TypeOf((*MockStateCallbackListener)(nil).OnValidatorRemoved), arg0, arg1) +} + +// OnValidatorStatusUpdated mocks base method. +func (m *MockStateCallbackListener) OnValidatorStatusUpdated(arg0 ids.ID, arg1 ids.NodeID, arg2 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnValidatorStatusUpdated", arg0, arg1, arg2) +} + +// OnValidatorStatusUpdated indicates an expected call of OnValidatorStatusUpdated. +func (mr *MockStateCallbackListenerMockRecorder) OnValidatorStatusUpdated(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnValidatorStatusUpdated", reflect.TypeOf((*MockStateCallbackListener)(nil).OnValidatorStatusUpdated), arg0, arg1, arg2) +} diff --git a/plugin/evm/validators/state.go b/plugin/evm/validators/state.go new file mode 100644 index 0000000000..f30418c220 --- /dev/null +++ b/plugin/evm/validators/state.go @@ -0,0 +1,327 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package validators + +import ( + "fmt" + "time" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/uptime" + "github.com/ava-labs/avalanchego/utils/set" +) + +var _ uptime.State = &state{} + +type dbUpdateStatus bool + +var ErrAlreadyExists = fmt.Errorf("validator already exists") + +const ( + updated dbUpdateStatus = true + deleted dbUpdateStatus = false +) + +type State interface { + uptime.State + // AddValidator adds a new validator to the state + AddValidator(vID ids.ID, nodeID ids.NodeID, startTimestamp uint64, isActive bool) error + // DeleteValidator deletes the validator from the state + DeleteValidator(vID ids.ID) error + // WriteState writes the validator state to the disk + WriteState() error + + // SetStatus sets the active status of the validator with the given vID + SetStatus(vID ids.ID, isActive bool) error + // GetStatus returns the active status of the validator with the given vID + GetStatus(vID ids.ID) (bool, error) + + // GetValidationIDs returns the validation IDs in the state + GetValidationIDs() set.Set[ids.ID] + // GetValidatorIDs returns the validator node IDs in the state + GetValidatorIDs() set.Set[ids.NodeID] + + // RegisterListener registers a listener to the state + RegisterListener(StateCallbackListener) +} + +// StateCallbackListener is a listener for the validator state +type StateCallbackListener interface { + // OnValidatorAdded is called when a new validator is added + OnValidatorAdded(vID ids.ID, nodeID ids.NodeID, startTime uint64, isActive bool) + // OnValidatorRemoved is called when a validator is removed + OnValidatorRemoved(vID ids.ID, nodeID ids.NodeID) + // OnValidatorStatusUpdated is called when a validator status is updated + OnValidatorStatusUpdated(vID ids.ID, nodeID ids.NodeID, isActive bool) +} + +type validatorData struct { + UpDuration time.Duration `serialize:"true"` + LastUpdated uint64 `serialize:"true"` + NodeID ids.NodeID `serialize:"true"` + StartTime uint64 `serialize:"true"` + IsActive bool `serialize:"true"` + + validationID ids.ID // database key +} + +type state struct { + data map[ids.ID]*validatorData // vID -> validatorData + index map[ids.NodeID]ids.ID // nodeID -> vID + // updatedData tracks the updates since WriteValidator was last called + updatedData map[ids.ID]dbUpdateStatus // vID -> updated status + db database.Database + + listeners []StateCallbackListener +} + +// NewState creates a new State, it also loads the data from the disk +func NewState(db database.Database) (State, error) { + s := &state{ + index: make(map[ids.NodeID]ids.ID), + data: make(map[ids.ID]*validatorData), + updatedData: make(map[ids.ID]dbUpdateStatus), + db: db, + } + if err := s.loadFromDisk(); err != nil { + return nil, fmt.Errorf("failed to load data from disk: %w", err) + } + return s, nil +} + +// GetUptime returns the uptime of the validator with the given nodeID +func (s *state) GetUptime( + nodeID ids.NodeID, +) (time.Duration, time.Time, error) { + data, err := s.getData(nodeID) + if err != nil { + return 0, time.Time{}, err + } + return data.UpDuration, data.getLastUpdated(), nil +} + +// SetUptime sets the uptime of the validator with the given nodeID +func (s *state) SetUptime( + nodeID ids.NodeID, + upDuration time.Duration, + lastUpdated time.Time, +) error { + data, err := s.getData(nodeID) + if err != nil { + return err + } + data.UpDuration = upDuration + data.setLastUpdated(lastUpdated) + + s.updatedData[data.validationID] = updated + return nil +} + +// GetStartTime returns the start time of the validator with the given nodeID +func (s *state) GetStartTime(nodeID ids.NodeID) (time.Time, error) { + data, err := s.getData(nodeID) + if err != nil { + return time.Time{}, err + } + return data.getStartTime(), nil +} + +// AddValidator adds a new validator to the state +// the new validator is marked as updated and will be written to the disk when WriteState is called +func (s *state) AddValidator(vID ids.ID, nodeID ids.NodeID, startTimestamp uint64, isActive bool) error { + data := &validatorData{ + NodeID: nodeID, + validationID: vID, + IsActive: isActive, + StartTime: startTimestamp, + UpDuration: 0, + LastUpdated: startTimestamp, + } + if err := s.addData(vID, data); err != nil { + return err + } + + s.updatedData[vID] = updated + + for _, listener := range s.listeners { + listener.OnValidatorAdded(vID, nodeID, startTimestamp, isActive) + } + return nil +} + +// DeleteValidator marks the validator as deleted +// marked validator will be deleted from disk when WriteState is called +func (s *state) DeleteValidator(vID ids.ID) error { + data, exists := s.data[vID] + if !exists { + return database.ErrNotFound + } + delete(s.data, data.validationID) + delete(s.index, data.NodeID) + + // mark as deleted for WriteValidator + s.updatedData[data.validationID] = deleted + + for _, listener := range s.listeners { + listener.OnValidatorRemoved(vID, data.NodeID) + } + return nil +} + +// WriteState writes the updated state to the disk +func (s *state) WriteState() error { + // TODO: consider adding batch size + batch := s.db.NewBatch() + for vID, updateStatus := range s.updatedData { + switch updateStatus { + case updated: + data := s.data[vID] + + dataBytes, err := vdrCodec.Marshal(codecVersion, data) + if err != nil { + return err + } + if err := batch.Put(vID[:], dataBytes); err != nil { + return err + } + case deleted: + if err := batch.Delete(vID[:]); err != nil { + return err + } + default: + return fmt.Errorf("unknown update status for %s", vID) + } + // we're done, remove the updated marker + delete(s.updatedData, vID) + } + return batch.Write() +} + +// SetStatus sets the active status of the validator with the given vID +func (s *state) SetStatus(vID ids.ID, isActive bool) error { + data, exists := s.data[vID] + if !exists { + return database.ErrNotFound + } + data.IsActive = isActive + s.updatedData[vID] = updated + + for _, listener := range s.listeners { + listener.OnValidatorStatusUpdated(vID, data.NodeID, isActive) + } + return nil +} + +// GetStatus returns the active status of the validator with the given vID +func (s *state) GetStatus(vID ids.ID) (bool, error) { + data, exists := s.data[vID] + if !exists { + return false, database.ErrNotFound + } + return data.IsActive, nil +} + +// GetValidationIDs returns the validation IDs in the state +func (s *state) GetValidationIDs() set.Set[ids.ID] { + ids := set.NewSet[ids.ID](len(s.data)) + for vID := range s.data { + ids.Add(vID) + } + return ids +} + +// GetValidatorIDs returns the validator IDs in the state +func (s *state) GetValidatorIDs() set.Set[ids.NodeID] { + ids := set.NewSet[ids.NodeID](len(s.index)) + for nodeID := range s.index { + ids.Add(nodeID) + } + return ids +} + +// RegisterListener registers a listener to the state +// OnValidatorAdded is called for all current validators on the provided listener before this function returns +func (s *state) RegisterListener(listener StateCallbackListener) { + s.listeners = append(s.listeners, listener) + + // notify the listener of the current state + for vID, data := range s.data { + listener.OnValidatorAdded(vID, data.NodeID, data.StartTime, data.IsActive) + } +} + +// parseValidatorData parses the data from the bytes into given validatorData +func parseValidatorData(bytes []byte, data *validatorData) error { + if len(bytes) != 0 { + if _, err := vdrCodec.Unmarshal(bytes, data); err != nil { + return err + } + } + return nil +} + +// Load the state from the disk +func (s *state) loadFromDisk() error { + it := s.db.NewIterator() + defer it.Release() + for it.Next() { + vIDBytes := it.Key() + vID, err := ids.ToID(vIDBytes) + if err != nil { + return fmt.Errorf("failed to parse validator ID: %w", err) + } + vdr := &validatorData{ + validationID: vID, + } + if err := parseValidatorData(it.Value(), vdr); err != nil { + return fmt.Errorf("failed to parse validator data: %w", err) + } + if err := s.addData(vID, vdr); err != nil { + return err + } + } + return it.Error() +} + +// addData adds the data to the state +// returns an error if the data already exists +func (s *state) addData(vID ids.ID, data *validatorData) error { + if _, exists := s.data[vID]; exists { + return fmt.Errorf("%w, vID: %s", ErrAlreadyExists, vID) + } + if _, exists := s.index[data.NodeID]; exists { + return fmt.Errorf("%w, nodeID: %s", ErrAlreadyExists, data.NodeID) + } + + s.data[vID] = data + s.index[data.NodeID] = vID + return nil +} + +// getData returns the data for the validator with the given nodeID +// returns database.ErrNotFound if the data does not exist +func (s *state) getData(nodeID ids.NodeID) (*validatorData, error) { + vID, exists := s.index[nodeID] + if !exists { + return nil, database.ErrNotFound + } + data, exists := s.data[vID] + if !exists { + return nil, database.ErrNotFound + } + return data, nil +} + +func (v *validatorData) setLastUpdated(t time.Time) { + v.LastUpdated = uint64(t.Unix()) +} + +func (v *validatorData) getLastUpdated() time.Time { + return time.Unix(int64(v.LastUpdated), 0) +} + +func (v *validatorData) getStartTime() time.Time { + return time.Unix(int64(v.StartTime), 0) +} diff --git a/plugin/evm/validators/state_test.go b/plugin/evm/validators/state_test.go new file mode 100644 index 0000000000..ecfd7d34a9 --- /dev/null +++ b/plugin/evm/validators/state_test.go @@ -0,0 +1,250 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package validators + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/memdb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/wrappers" +) + +func TestState(t *testing.T) { + require := require.New(t) + db := memdb.New() + state, err := NewState(db) + require.NoError(err) + + // get non-existent uptime + nodeID := ids.GenerateTestNodeID() + vID := ids.GenerateTestID() + _, _, err = state.GetUptime(nodeID) + require.ErrorIs(err, database.ErrNotFound) + + // set non-existent uptime + startTime := time.Now() + err = state.SetUptime(nodeID, 1, startTime) + require.ErrorIs(err, database.ErrNotFound) + + // add new validator + state.AddValidator(vID, nodeID, uint64(startTime.Unix()), true) + + // adding the same validator should fail + err = state.AddValidator(vID, ids.GenerateTestNodeID(), uint64(startTime.Unix()), true) + require.ErrorIs(err, ErrAlreadyExists) + // adding the same nodeID should fail + err = state.AddValidator(ids.GenerateTestID(), nodeID, uint64(startTime.Unix()), true) + require.ErrorIs(err, ErrAlreadyExists) + + // get uptime + uptime, lastUpdated, err := state.GetUptime(nodeID) + require.NoError(err) + require.Equal(time.Duration(0), uptime) + require.Equal(startTime.Unix(), lastUpdated.Unix()) + + // set uptime + newUptime := 2 * time.Minute + newLastUpdated := lastUpdated.Add(time.Hour) + require.NoError(state.SetUptime(nodeID, newUptime, newLastUpdated)) + // get new uptime + uptime, lastUpdated, err = state.GetUptime(nodeID) + require.NoError(err) + require.Equal(newUptime, uptime) + require.Equal(newLastUpdated, lastUpdated) + + // set status + require.NoError(state.SetStatus(vID, false)) + // get status + status, err := state.GetStatus(vID) + require.NoError(err) + require.False(status) + + // delete uptime + require.NoError(state.DeleteValidator(vID)) + + // get deleted uptime + _, _, err = state.GetUptime(nodeID) + require.ErrorIs(err, database.ErrNotFound) +} + +func TestWriteValidator(t *testing.T) { + require := require.New(t) + db := memdb.New() + state, err := NewState(db) + require.NoError(err) + // write empty uptimes + require.NoError(state.WriteState()) + + // load uptime + nodeID := ids.GenerateTestNodeID() + vID := ids.GenerateTestID() + startTime := time.Now() + require.NoError(state.AddValidator(vID, nodeID, uint64(startTime.Unix()), true)) + + // write state, should reflect to DB + require.NoError(state.WriteState()) + require.True(db.Has(vID[:])) + + // set uptime + newUptime := 2 * time.Minute + newLastUpdated := startTime.Add(time.Hour) + require.NoError(state.SetUptime(nodeID, newUptime, newLastUpdated)) + require.NoError(state.WriteState()) + + // refresh state, should load from DB + state, err = NewState(db) + require.NoError(err) + + // get uptime + uptime, lastUpdated, err := state.GetUptime(nodeID) + require.NoError(err) + require.Equal(newUptime, uptime) + require.Equal(newLastUpdated.Unix(), lastUpdated.Unix()) + + // delete + require.NoError(state.DeleteValidator(vID)) + + // write state, should reflect to DB + require.NoError(state.WriteState()) + require.False(db.Has(vID[:])) +} + +func TestParseValidator(t *testing.T) { + testNodeID, err := ids.NodeIDFromString("NodeID-CaBYJ9kzHvrQFiYWowMkJGAQKGMJqZoat") + require.NoError(t, err) + type test struct { + name string + bytes []byte + expected *validatorData + expectedErr error + } + tests := []test{ + { + name: "nil", + bytes: nil, + expected: &validatorData{ + LastUpdated: 0, + StartTime: 0, + }, + expectedErr: nil, + }, + { + name: "empty", + bytes: []byte{}, + expected: &validatorData{ + LastUpdated: 0, + StartTime: 0, + }, + expectedErr: nil, + }, + { + name: "valid", + bytes: []byte{ + // codec version + 0x00, 0x00, + // up duration + 0x00, 0x00, 0x00, 0x00, 0x00, 0x5B, 0x8D, 0x80, + // last updated + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0D, 0xBB, 0xA0, + // node ID + 0x7e, 0xef, 0xe8, 0x8a, 0x45, 0xfb, 0x7a, 0xc4, + 0xb0, 0x59, 0xc9, 0x33, 0x71, 0x0a, 0x57, 0x33, + 0xff, 0x9f, 0x4b, 0xab, + // start time + 0x00, 0x00, 0x00, 0x00, 0x00, 0x5B, 0x8D, 0x80, + // status + 0x01, + }, + expected: &validatorData{ + UpDuration: time.Duration(6000000), + LastUpdated: 900000, + NodeID: testNodeID, + StartTime: 6000000, + IsActive: true, + }, + }, + { + name: "invalid codec version", + bytes: []byte{ + // codec version + 0x00, 0x02, + // up duration + 0x00, 0x00, 0x00, 0x00, 0x00, 0x5B, 0x8D, 0x80, + // last updated + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0D, 0xBB, 0xA0, + }, + expected: nil, + expectedErr: codec.ErrUnknownVersion, + }, + { + name: "short byte len", + bytes: []byte{ + // codec version + 0x00, 0x00, + // up duration + 0x00, 0x00, 0x00, 0x00, 0x00, 0x5B, 0x8D, 0x80, + // last updated + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0D, 0xBB, 0xA0, + }, + expected: nil, + expectedErr: wrappers.ErrInsufficientLength, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + var data validatorData + err := parseValidatorData(tt.bytes, &data) + require.ErrorIs(err, tt.expectedErr) + if tt.expectedErr != nil { + return + } + require.Equal(tt.expected, &data) + }) + } +} + +func TestStateListener(t *testing.T) { + require := require.New(t) + db := memdb.New() + state, err := NewState(db) + require.NoError(err) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expectedvID := ids.GenerateTestID() + expectedNodeID := ids.GenerateTestNodeID() + expectedStartTime := time.Now() + mockListener := NewMockStateCallbackListener(ctrl) + // add initial validator to test RegisterListener + initialvID := ids.GenerateTestID() + initialNodeID := ids.GenerateTestNodeID() + initialStartTime := time.Now() + + // add initial validator + require.NoError(state.AddValidator(initialvID, initialNodeID, uint64(initialStartTime.Unix()), true)) + + // register listener + mockListener.EXPECT().OnValidatorAdded(initialvID, initialNodeID, uint64(initialStartTime.Unix()), true) + state.RegisterListener(mockListener) + + // add new validator + mockListener.EXPECT().OnValidatorAdded(expectedvID, expectedNodeID, uint64(expectedStartTime.Unix()), true) + require.NoError(state.AddValidator(expectedvID, expectedNodeID, uint64(expectedStartTime.Unix()), true)) + + // set status + mockListener.EXPECT().OnValidatorStatusUpdated(expectedvID, expectedNodeID, false) + require.NoError(state.SetStatus(expectedvID, false)) + + // remove validator + mockListener.EXPECT().OnValidatorRemoved(expectedvID, expectedNodeID) + require.NoError(state.DeleteValidator(expectedvID)) +} diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index 391dc8e13c..73aa62ccd7 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -1,2 +1,3 @@ github.com/ava-labs/subnet-evm/precompile/precompileconfig=Predicater,Config,ChainConfig,Accepter=precompile/precompileconfig/mocks.go github.com/ava-labs/subnet-evm/precompile/contract=BlockContext,AccessibleState,StateDB=precompile/contract/mocks.go +github.com/ava-labs/subnet-evm/plugin/evm/validators=StateCallbackListener=plugin/evm/validators/mock_listener.go