diff --git a/validator/client/BUILD.bazel b/validator/client/BUILD.bazel index c04575ece14b..b5ae1700fdb6 100644 --- a/validator/client/BUILD.bazel +++ b/validator/client/BUILD.bazel @@ -46,6 +46,7 @@ go_library( "//validator/graffiti:go_default_library", "//validator/keymanager:go_default_library", "//validator/keymanager/imported:go_default_library", + "//validator/keymanager/remote:go_default_library", "//validator/slashing-protection/iface:go_default_library", "@com_github_dgraph_io_ristretto//:go_default_library", "@com_github_gogo_protobuf//proto:go_default_library", @@ -107,6 +108,7 @@ go_test( "//shared/mock:go_default_library", "//shared/params:go_default_library", "//shared/slotutil:go_default_library", + "//shared/slotutil/testing:go_default_library", "//shared/testutil:go_default_library", "//shared/testutil/assert:go_default_library", "//shared/testutil/require:go_default_library", @@ -117,6 +119,7 @@ go_test( "//validator/db/testing:go_default_library", "//validator/graffiti:go_default_library", "//validator/keymanager/derived:go_default_library", + "//validator/keymanager/remote:go_default_library", "//validator/slashing-protection/local/standard-protection-format:go_default_library", "//validator/testing:go_default_library", "@com_github_gogo_protobuf//types:go_default_library", diff --git a/validator/client/key_reload_test.go b/validator/client/key_reload_test.go index 3a068ac4535b..51a0b30a0f93 100644 --- a/validator/client/key_reload_test.go +++ b/validator/client/key_reload_test.go @@ -6,12 +6,12 @@ import ( "github.com/golang/mock/gomock" "github.com/pkg/errors" - types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/bls" "github.com/prysmaticlabs/prysm/shared/mock" "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" + "github.com/prysmaticlabs/prysm/validator/client/testutil" logTest "github.com/sirupsen/logrus/hooks/test" ) @@ -42,7 +42,7 @@ func TestValidator_HandleKeyReload(t *testing.T) { genesisTime: 1, } - resp := generateResponse([][]byte{inactivePubKey[:], activePubKey[:]}) + resp := testutil.GenerateMultipleValidatorStatusResponse([][]byte{inactivePubKey[:], activePubKey[:]}) resp.Statuses[0].Status = ethpb.ValidatorStatus_UNKNOWN_STATUS resp.Statuses[1].Status = ethpb.ValidatorStatus_ACTIVE client.EXPECT().MultipleValidatorStatus( @@ -78,7 +78,7 @@ func TestValidator_HandleKeyReload(t *testing.T) { genesisTime: 1, } - resp := generateResponse([][]byte{inactivePubKey[:]}) + resp := testutil.GenerateMultipleValidatorStatusResponse([][]byte{inactivePubKey[:]}) resp.Statuses[0].Status = ethpb.ValidatorStatus_UNKNOWN_STATUS client.EXPECT().MultipleValidatorStatus( gomock.Any(), @@ -122,20 +122,3 @@ func TestValidator_HandleKeyReload(t *testing.T) { assert.ErrorContains(t, "error", err) }) } - -func generateResponse(pubkeys [][]byte) *ethpb.MultipleValidatorStatusResponse { - resp := ðpb.MultipleValidatorStatusResponse{ - PublicKeys: make([][]byte, len(pubkeys)), - Statuses: make([]*ethpb.ValidatorStatusResponse, len(pubkeys)), - Indices: make([]types.ValidatorIndex, len(pubkeys)), - } - for i, key := range pubkeys { - resp.PublicKeys[i] = key - resp.Statuses[i] = ðpb.ValidatorStatusResponse{ - Status: ethpb.ValidatorStatus_UNKNOWN_STATUS, - } - resp.Indices[i] = types.ValidatorIndex(i) - } - - return resp -} diff --git a/validator/client/runner.go b/validator/client/runner.go index fe7a8640811d..1b21187b7a0d 100644 --- a/validator/client/runner.go +++ b/validator/client/runner.go @@ -100,7 +100,7 @@ func run(ctx context.Context, v iface.Validator) { handleAssignmentError(err, headSlot) } - accountsChangedChan := make(chan [][48]byte) + accountsChangedChan := make(chan [][48]byte, 1) sub := v.GetKeymanager().SubscribeAccountChanges(accountsChangedChan) for { slotCtx, cancel := context.WithCancel(ctx) diff --git a/validator/client/testutil/helper.go b/validator/client/testutil/helper.go index 835eef88de18..862e79603e7b 100644 --- a/validator/client/testutil/helper.go +++ b/validator/client/testutil/helper.go @@ -1,6 +1,28 @@ package testutil -import "github.com/prysmaticlabs/prysm/shared/bytesutil" +import ( + types "github.com/prysmaticlabs/eth2-types" + ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/bytesutil" +) // ActiveKey represents a public key whose status is ACTIVE. var ActiveKey = bytesutil.ToBytes48([]byte("active")) + +// GenerateMultipleValidatorStatusResponse prepares a response from the passed in keys. +func GenerateMultipleValidatorStatusResponse(pubkeys [][]byte) *ethpb.MultipleValidatorStatusResponse { + resp := ðpb.MultipleValidatorStatusResponse{ + PublicKeys: make([][]byte, len(pubkeys)), + Statuses: make([]*ethpb.ValidatorStatusResponse, len(pubkeys)), + Indices: make([]types.ValidatorIndex, len(pubkeys)), + } + for i, key := range pubkeys { + resp.PublicKeys[i] = key + resp.Statuses[i] = ðpb.ValidatorStatusResponse{ + Status: ethpb.ValidatorStatus_UNKNOWN_STATUS, + } + resp.Indices[i] = types.ValidatorIndex(i) + } + + return resp +} diff --git a/validator/client/validator.go b/validator/client/validator.go index c0d4a90f36c7..dc681a2c5069 100644 --- a/validator/client/validator.go +++ b/validator/client/validator.go @@ -68,7 +68,7 @@ type validator struct { highestValidSlot types.Slot domainDataCache *ristretto.Cache aggregatedSlotCommitteeIDCache *lru.Cache - ticker *slotutil.SlotTicker + ticker slotutil.Ticker prevBalance map[[48]byte]uint64 duties *ethpb.DutiesResponse startBalances map[[48]byte]uint64 diff --git a/validator/client/wait_for_activation.go b/validator/client/wait_for_activation.go index 0da40c7c5019..997ea67f5450 100644 --- a/validator/client/wait_for_activation.go +++ b/validator/client/wait_for_activation.go @@ -12,6 +12,7 @@ import ( "github.com/prysmaticlabs/prysm/shared/params" "github.com/prysmaticlabs/prysm/shared/slotutil" "github.com/prysmaticlabs/prysm/shared/traceutil" + "github.com/prysmaticlabs/prysm/validator/keymanager/remote" "go.opencensus.io/trace" ) @@ -24,7 +25,7 @@ import ( func (v *validator) WaitForActivation(ctx context.Context, accountsChangedChan chan [][48]byte) error { // Monitor the key manager for updates. if accountsChangedChan == nil { - accountsChangedChan = make(chan [][48]byte) + accountsChangedChan = make(chan [][48]byte, 1) sub := v.GetKeymanager().SubscribeAccountChanges(accountsChangedChan) defer func() { sub.Unsubscribe() @@ -87,47 +88,87 @@ func (v *validator) waitForActivation(ctx context.Context, accountsChangedChan < time.Sleep(time.Second * time.Duration(mathutil.Min(uint64(attempts), 60))) return v.waitForActivation(incrementRetries(ctx), accountsChangedChan) } - for { - select { - case <-accountsChangedChan: - // Accounts (keys) changed, restart the process. - return v.waitForActivation(ctx, accountsChangedChan) - default: - res, err := stream.Recv() - // If the stream is closed, we stop the loop. - if errors.Is(err, io.EOF) { - break - } - // If context is canceled we return from the function. + + remoteKm, ok := v.keyManager.(remote.RemoteKeymanager) + if ok { + for range v.NextSlot() { if ctx.Err() == context.Canceled { - return errors.Wrap(ctx.Err(), "context has been canceled so shutting down the loop") + return errors.Wrap(ctx.Err(), "context canceled, not waiting for activation anymore") } + + validatingKeys, err = remoteKm.ReloadPublicKeys(ctx) if err != nil { - traceutil.AnnotateError(span, err) - attempts := streamAttempts(ctx) - log.WithError(err).WithField("attempts", attempts). - Error("Stream broken while waiting for activation. Reconnecting...") - // Reconnection attempt backoff, up to 60s. - time.Sleep(time.Second * time.Duration(mathutil.Min(uint64(attempts), 60))) - return v.waitForActivation(incrementRetries(ctx), accountsChangedChan) + return errors.Wrap(err, msgCouldNotFetchKeys) } - - statuses := make([]*validatorStatus, len(res.Statuses)) - for i, s := range res.Statuses { + statusRequestKeys := make([][]byte, len(validatingKeys)) + for i := range validatingKeys { + statusRequestKeys[i] = validatingKeys[i][:] + } + resp, err := v.validatorClient.MultipleValidatorStatus(ctx, ðpb.MultipleValidatorStatusRequest{ + PublicKeys: statusRequestKeys, + }) + if err != nil { + return err + } + statuses := make([]*validatorStatus, len(resp.Statuses)) + for i, s := range resp.Statuses { statuses[i] = &validatorStatus{ - publicKey: s.PublicKey, - status: s.Status, - index: s.Index, + publicKey: resp.PublicKeys[i], + status: s, + index: resp.Indices[i], } } + valActivated := v.checkAndLogValidatorStatus(statuses) if valActivated { logActiveValidatorStatus(statuses) - } else { - continue + break + } + } + } else { + for { + select { + case <-accountsChangedChan: + // Accounts (keys) changed, restart the process. + return v.waitForActivation(ctx, accountsChangedChan) + default: + res, err := stream.Recv() + // If the stream is closed, we stop the loop. + if errors.Is(err, io.EOF) { + break + } + // If context is canceled we return from the function. + if ctx.Err() == context.Canceled { + return errors.Wrap(ctx.Err(), "context has been canceled so shutting down the loop") + } + if err != nil { + traceutil.AnnotateError(span, err) + attempts := streamAttempts(ctx) + log.WithError(err).WithField("attempts", attempts). + Error("Stream broken while waiting for activation. Reconnecting...") + // Reconnection attempt backoff, up to 60s. + time.Sleep(time.Second * time.Duration(mathutil.Min(uint64(attempts), 60))) + return v.waitForActivation(incrementRetries(ctx), accountsChangedChan) + } + + statuses := make([]*validatorStatus, len(res.Statuses)) + for i, s := range res.Statuses { + statuses[i] = &validatorStatus{ + publicKey: s.PublicKey, + status: s.Status, + index: s.Index, + } + } + + valActivated := v.checkAndLogValidatorStatus(statuses) + if valActivated { + logActiveValidatorStatus(statuses) + } else { + continue + } } + break } - break } v.ticker = slotutil.NewSlotTicker(time.Unix(int64(v.genesisTime), 0), params.BeaconConfig().SecondsPerSlot) diff --git a/validator/client/wait_for_activation_test.go b/validator/client/wait_for_activation_test.go index cb6d1d3e5883..5beac6ec2ebd 100644 --- a/validator/client/wait_for_activation_test.go +++ b/validator/client/wait_for_activation_test.go @@ -8,13 +8,18 @@ import ( "github.com/golang/mock/gomock" "github.com/pkg/errors" + types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/prysm/shared/bls" + "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/mock" + slotutilmock "github.com/prysmaticlabs/prysm/shared/slotutil/testing" "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" walletMock "github.com/prysmaticlabs/prysm/validator/accounts/testing" + "github.com/prysmaticlabs/prysm/validator/client/testutil" "github.com/prysmaticlabs/prysm/validator/keymanager/derived" + "github.com/prysmaticlabs/prysm/validator/keymanager/remote" constant "github.com/prysmaticlabs/prysm/validator/testing" logTest "github.com/sirupsen/logrus/hooks/test" "github.com/tyler-smith/go-bip39" @@ -378,3 +383,79 @@ func TestWaitForActivation_AccountsChanged(t *testing.T) { assert.LogsContain(t, hook, "Validator activated") }) } + +func TestWaitForActivation_RemoteKeymanager(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mock.NewMockBeaconNodeValidatorClient(ctrl) + stream := mock.NewMockBeaconNodeValidator_WaitForActivationClient(ctrl) + client.EXPECT().WaitForActivation( + gomock.Any(), + gomock.Any(), + ).Return(stream, nil /* err */).AnyTimes() + + inactiveKey := bytesutil.ToBytes48([]byte("inactive")) + activeKey := bytesutil.ToBytes48([]byte("active")) + km := &remote.MockKeymanager{ + PublicKeys: [][48]byte{inactiveKey, activeKey}, + } + slot := types.Slot(0) + + t.Run("activated", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + hook := logTest.NewGlobal() + + tickerChan := make(chan types.Slot) + ticker := &slotutilmock.MockTicker{ + Channel: tickerChan, + } + v := validator{ + validatorClient: client, + keyManager: km, + ticker: ticker, + } + go func() { + tickerChan <- slot + // Cancel after timeout to avoid waiting on channel forever in case test goes wrong. + time.Sleep(time.Second) + cancel() + }() + + resp := testutil.GenerateMultipleValidatorStatusResponse([][]byte{inactiveKey[:], activeKey[:]}) + resp.Statuses[0].Status = ethpb.ValidatorStatus_UNKNOWN_STATUS + resp.Statuses[1].Status = ethpb.ValidatorStatus_ACTIVE + client.EXPECT().MultipleValidatorStatus( + gomock.Any(), + ðpb.MultipleValidatorStatusRequest{ + PublicKeys: [][]byte{inactiveKey[:], activeKey[:]}, + }, + ).Return(resp, nil /* err */) + + err := v.waitForActivation(ctx, nil /* accountsChangedChan */) + require.NoError(t, err) + assert.LogsContain(t, hook, "Waiting for deposit to be observed by beacon node") + assert.LogsContain(t, hook, "Validator activated") + }) + + t.Run("cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + tickerChan := make(chan types.Slot) + ticker := &slotutilmock.MockTicker{ + Channel: tickerChan, + } + v := validator{ + validatorClient: client, + keyManager: km, + ticker: ticker, + } + go func() { + cancel() + tickerChan <- slot + }() + + err := v.waitForActivation(ctx, nil /* accountsChangedChan */) + assert.ErrorContains(t, "context canceled, not waiting for activation anymore", err) + }) +} diff --git a/validator/keymanager/BUILD.bazel b/validator/keymanager/BUILD.bazel index fcc8461609b9..bc7b169c3eab 100644 --- a/validator/keymanager/BUILD.bazel +++ b/validator/keymanager/BUILD.bazel @@ -3,7 +3,10 @@ load("@prysm//tools/go:def.bzl", "go_library") go_library( name = "go_default_library", - srcs = ["types.go"], + srcs = [ + "constants.go", + "types.go", + ], importpath = "github.com/prysmaticlabs/prysm/validator/keymanager", visibility = [ "//tools/keystores:__pkg__", diff --git a/validator/keymanager/constants.go b/validator/keymanager/constants.go new file mode 100644 index 000000000000..dc0af7f43856 --- /dev/null +++ b/validator/keymanager/constants.go @@ -0,0 +1,3 @@ +package keymanager + +const KeysReloaded = "Reloaded validator keys into keymanager" diff --git a/validator/keymanager/imported/refresh.go b/validator/keymanager/imported/refresh.go index 1f33d69bf2c9..55115fc7d84a 100644 --- a/validator/keymanager/imported/refresh.go +++ b/validator/keymanager/imported/refresh.go @@ -13,6 +13,7 @@ import ( "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/featureconfig" "github.com/prysmaticlabs/prysm/shared/fileutil" + "github.com/prysmaticlabs/prysm/validator/keymanager" keystorev4 "github.com/wealdtech/go-eth2-wallet-encryptor-keystorev4" ) @@ -118,7 +119,7 @@ func (km *Keymanager) reloadAccountsFromKeystore(keystore *AccountsKeystoreRepre if err := km.initializeKeysCachesFromKeystore(); err != nil { return err } - log.Info("Reloaded validator keys into keymanager") + log.Info(keymanager.KeysReloaded) km.accountsChangedFeed.Send(pubKeys) return nil } diff --git a/validator/keymanager/remote/BUILD.bazel b/validator/keymanager/remote/BUILD.bazel index 70aec92eec1e..4fa02c133560 100644 --- a/validator/keymanager/remote/BUILD.bazel +++ b/validator/keymanager/remote/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "doc.go", "keymanager.go", "log.go", + "mock_keymanager.go", ], importpath = "github.com/prysmaticlabs/prysm/validator/keymanager/remote", visibility = [ @@ -18,6 +19,7 @@ go_library( "//shared/bls:go_default_library", "//shared/bytesutil:go_default_library", "//shared/event:go_default_library", + "//validator/keymanager:go_default_library", "@com_github_logrusorgru_aurora//:go_default_library", "@com_github_pkg_errors//:go_default_library", "@com_github_sirupsen_logrus//:go_default_library", @@ -34,10 +36,14 @@ go_test( deps = [ "//proto/validator/accounts/v2:go_default_library", "//shared/bls:go_default_library", + "//shared/bytesutil:go_default_library", + "//shared/event:go_default_library", "//shared/mock:go_default_library", "//shared/params:go_default_library", "//shared/testutil/assert:go_default_library", "//shared/testutil/require:go_default_library", + "//validator/keymanager:go_default_library", "@com_github_golang_mock//gomock:go_default_library", + "@com_github_sirupsen_logrus//hooks/test:go_default_library", ], ) diff --git a/validator/keymanager/remote/keymanager.go b/validator/keymanager/remote/keymanager.go index 066959955036..8ac668f26a71 100644 --- a/validator/keymanager/remote/keymanager.go +++ b/validator/keymanager/remote/keymanager.go @@ -1,6 +1,7 @@ package remote import ( + "bytes" "context" "crypto/tls" "crypto/x509" @@ -8,6 +9,7 @@ import ( "fmt" "io" "io/ioutil" + "sort" "strings" "github.com/golang/protobuf/ptypes/empty" @@ -17,6 +19,7 @@ import ( "github.com/prysmaticlabs/prysm/shared/bls" "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/event" + "github.com/prysmaticlabs/prysm/validator/keymanager" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -30,6 +33,12 @@ var ( ErrSigningDenied = errors.New("signing request was denied by remote server") ) +// RemoteKeymanager defines the interface for remote Prysm wallets. +type RemoteKeymanager interface { + keymanager.IKeymanager + ReloadPublicKeys(ctx context.Context) ([][48]byte, error) +} + // KeymanagerOpts for a remote keymanager. type KeymanagerOpts struct { RemoteCertificate *CertificateConfig `json:"remote_cert"` @@ -55,9 +64,10 @@ type SetupConfig struct { // Keymanager implementation using remote signing keys via gRPC. type Keymanager struct { - opts *KeymanagerOpts - client validatorpb.RemoteSignerClient - accountsByPubkey map[[48]byte]string + opts *KeymanagerOpts + client validatorpb.RemoteSignerClient + orderedPubKeys [][48]byte + accountsChangedFeed *event.Feed } // NewKeymanager instantiates a new imported keymanager from configuration options. @@ -118,9 +128,10 @@ func NewKeymanager(_ context.Context, cfg *SetupConfig) (*Keymanager, error) { } client := validatorpb.NewRemoteSignerClient(conn) k := &Keymanager{ - opts: cfg.Opts, - client: client, - accountsByPubkey: make(map[[48]byte]string), + opts: cfg.Opts, + client: client, + orderedPubKeys: make([][48]byte, 0), + accountsChangedFeed: new(event.Feed), } return k, nil } @@ -196,6 +207,30 @@ func (km *Keymanager) KeymanagerOpts() *KeymanagerOpts { return km.opts } +func (km *Keymanager) ReloadPublicKeys(ctx context.Context) ([][48]byte, error) { + pubKeys, err := km.FetchValidatingPublicKeys(ctx) + if err != nil { + return nil, errors.Wrap(err, "could not reload public keys") + } + + sort.Slice(pubKeys, func(i, j int) bool { return bytes.Compare(pubKeys[i][:], pubKeys[j][:]) == -1 }) + if len(km.orderedPubKeys) != len(pubKeys) { + log.Info(keymanager.KeysReloaded) + km.accountsChangedFeed.Send(pubKeys) + } else { + for i := range km.orderedPubKeys { + if !bytes.Equal(km.orderedPubKeys[i][:], pubKeys[i][:]) { + log.Info(keymanager.KeysReloaded) + km.accountsChangedFeed.Send(pubKeys) + break + } + } + } + + km.orderedPubKeys = pubKeys + return km.orderedPubKeys, nil +} + // FetchValidatingPublicKeys fetches the list of public keys that should be used to validate with. func (km *Keymanager) FetchValidatingPublicKeys(ctx context.Context) ([][48]byte, error) { resp, err := km.client.ListValidatingPublicKeys(ctx, &empty.Empty{}) @@ -224,10 +259,9 @@ func (km *Keymanager) Sign(ctx context.Context, req *validatorpb.SignRequest) (b return bls.SignatureFromBytes(resp.Signature) } -// SubscribeAccountChanges is currently NOT IMPLEMENTED for the remote keymanager. -// INVOKING THIS FUNCTION HAS NO EFFECT! -func (km *Keymanager) SubscribeAccountChanges(_ chan [][48]byte) event.Subscription { - return event.NewSubscription(func(i <-chan struct{}) error { - return nil - }) +// SubscribeAccountChanges creates an event subscription for a channel +// to listen for public key changes at runtime, such as when new validator accounts +// are imported into the keymanager while the validator process is running. +func (km *Keymanager) SubscribeAccountChanges(pubKeysChan chan [][48]byte) event.Subscription { + return km.accountsChangedFeed.Subscribe(pubKeysChan) } diff --git a/validator/keymanager/remote/keymanager_test.go b/validator/keymanager/remote/keymanager_test.go index 0d3ae5cadfd2..958b557360dc 100644 --- a/validator/keymanager/remote/keymanager_test.go +++ b/validator/keymanager/remote/keymanager_test.go @@ -14,10 +14,14 @@ import ( "github.com/golang/mock/gomock" validatorpb "github.com/prysmaticlabs/prysm/proto/validator/accounts/v2" "github.com/prysmaticlabs/prysm/shared/bls" + "github.com/prysmaticlabs/prysm/shared/bytesutil" + "github.com/prysmaticlabs/prysm/shared/event" "github.com/prysmaticlabs/prysm/shared/mock" "github.com/prysmaticlabs/prysm/shared/params" "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" + "github.com/prysmaticlabs/prysm/validator/keymanager" + logTest "github.com/sirupsen/logrus/hooks/test" ) var validClientCert = `-----BEGIN CERTIFICATE----- @@ -267,7 +271,8 @@ func TestRemoteKeymanager_FetchValidatingPublicKeys(t *testing.T) { ctrl := gomock.NewController(t) m := mock.NewMockRemoteSignerClient(ctrl) k := &Keymanager{ - client: m, + client: m, + accountsChangedFeed: new(event.Feed), } // Expect error handling to work. @@ -330,3 +335,79 @@ func TestUnmarshalOptionsFile_DefaultRequireTls(t *testing.T) { assert.NoError(t, err) assert.Equal(t, true, opts.RemoteCertificate.RequireTls) } + +func TestReloadPublicKeys(t *testing.T) { + hook := logTest.NewGlobal() + ctx := context.Background() + ctrl := gomock.NewController(t) + m := mock.NewMockRemoteSignerClient(ctrl) + + k := &Keymanager{ + client: m, + accountsChangedFeed: new(event.Feed), + orderedPubKeys: [][48]byte{bytesutil.ToBytes48([]byte("100"))}, + } + + // Add key + m.EXPECT().ListValidatingPublicKeys( + gomock.Any(), // ctx + gomock.Any(), // epoch + ).Return(&validatorpb.ListPublicKeysResponse{ + // Return keys in reverse order to verify ordering + ValidatingPublicKeys: [][]byte{[]byte("200"), []byte("100")}, + }, nil /* err */) + + keys, err := k.ReloadPublicKeys(ctx) + require.NoError(t, err) + assert.DeepEqual(t, [][48]byte{bytesutil.ToBytes48([]byte("100")), bytesutil.ToBytes48([]byte("200"))}, k.orderedPubKeys) + assert.DeepEqual(t, keys, k.orderedPubKeys) + assert.LogsContain(t, hook, keymanager.KeysReloaded) + + hook.Reset() + + // Remove key + m.EXPECT().ListValidatingPublicKeys( + gomock.Any(), // ctx + gomock.Any(), // epoch + ).Return(&validatorpb.ListPublicKeysResponse{ + ValidatingPublicKeys: [][]byte{[]byte("200")}, + }, nil /* err */) + + keys, err = k.ReloadPublicKeys(ctx) + require.NoError(t, err) + assert.DeepEqual(t, [][48]byte{bytesutil.ToBytes48([]byte("200"))}, k.orderedPubKeys) + assert.DeepEqual(t, keys, k.orderedPubKeys) + assert.LogsContain(t, hook, keymanager.KeysReloaded) + + hook.Reset() + + // Change key + m.EXPECT().ListValidatingPublicKeys( + gomock.Any(), // ctx + gomock.Any(), // epoch + ).Return(&validatorpb.ListPublicKeysResponse{ + ValidatingPublicKeys: [][]byte{[]byte("300")}, + }, nil /* err */) + + keys, err = k.ReloadPublicKeys(ctx) + require.NoError(t, err) + assert.DeepEqual(t, [][48]byte{bytesutil.ToBytes48([]byte("300"))}, k.orderedPubKeys) + assert.DeepEqual(t, keys, k.orderedPubKeys) + assert.LogsContain(t, hook, keymanager.KeysReloaded) + + hook.Reset() + + // No change + m.EXPECT().ListValidatingPublicKeys( + gomock.Any(), // ctx + gomock.Any(), // epoch + ).Return(&validatorpb.ListPublicKeysResponse{ + ValidatingPublicKeys: [][]byte{[]byte("300")}, + }, nil /* err */) + + keys, err = k.ReloadPublicKeys(ctx) + require.NoError(t, err) + assert.DeepEqual(t, [][48]byte{bytesutil.ToBytes48([]byte("300"))}, k.orderedPubKeys) + assert.DeepEqual(t, keys, k.orderedPubKeys) + assert.LogsDoNotContain(t, hook, keymanager.KeysReloaded) +} diff --git a/validator/keymanager/remote/mock_keymanager.go b/validator/keymanager/remote/mock_keymanager.go new file mode 100644 index 000000000000..19911595f991 --- /dev/null +++ b/validator/keymanager/remote/mock_keymanager.go @@ -0,0 +1,29 @@ +package remote + +import ( + "context" + + validatorpb "github.com/prysmaticlabs/prysm/proto/validator/accounts/v2" + "github.com/prysmaticlabs/prysm/shared/bls" + "github.com/prysmaticlabs/prysm/shared/event" +) + +type MockKeymanager struct { + PublicKeys [][48]byte +} + +func (m *MockKeymanager) FetchValidatingPublicKeys(context.Context) ([][48]byte, error) { + return m.PublicKeys, nil +} + +func (*MockKeymanager) Sign(context.Context, *validatorpb.SignRequest) (bls.Signature, error) { + panic("implement me") +} + +func (*MockKeymanager) SubscribeAccountChanges(chan [][48]byte) event.Subscription { + panic("implement me") +} + +func (m *MockKeymanager) ReloadPublicKeys(context.Context) ([][48]byte, error) { + return m.PublicKeys, nil +}