diff --git a/core/message-handling.go b/core/message-handling.go index 52f8f2c..d4b569f 100644 --- a/core/message-handling.go +++ b/core/message-handling.go @@ -164,7 +164,8 @@ func defaultMessageHandlers(id uint32, log messagelog.MessageLog, unicastLogs ma validateRequest := makeRequestValidator(verifyMessageSignature) validatePrepare := makePrepareValidator(n, verifyUI, validateRequest) validateCommit := makeCommitValidator(verifyUI, validatePrepare) - validateMessage := makeMessageValidator(validateRequest, validatePrepare, validateCommit) + validateReqViewChange := makeReqViewChangeValidator(verifyMessageSignature) + validateMessage := makeMessageValidator(validateRequest, validatePrepare, validateCommit, validateReqViewChange) applyCommit := makeCommitApplier(collectCommitment) applyPrepare := makePrepareApplier(id, prepareSeq, collectCommitment, handleGeneratedMessage, stopPrepTimer) @@ -187,7 +188,12 @@ func defaultMessageHandlers(id uint32, log messagelog.MessageLog, unicastLogs ma processViewMessage := makeViewMessageProcessor(viewState, applyPeerMessage) processUIMessage := makeUIMessageProcessor(captureUI, processViewMessage) processEmbedded := makeEmbeddedMessageProcessor(processMessageThunk, logger) - processPeerMessage := makePeerMessageProcessor(processEmbedded, processUIMessage) + + collectReqViewChange := makeReqViewChangeCollector(f) + startViewChange := makeViewChangeStarter(id, viewState, log, handleGeneratedMessage) + processReqViewChange := makeReqViewChangeProcessor(collectReqViewChange, startViewChange) + + processPeerMessage := makePeerMessageProcessor(processEmbedded, processUIMessage, processReqViewChange) processMessage = makeMessageProcessor(processRequest, processPeerMessage) handleOwnMessage = makeOwnMessageHandler(processMessage) handlePeerMessage = makePeerMessageHandler(validateMessage, processMessage) @@ -405,7 +411,7 @@ func makeClientMessageHandler(validateRequest requestValidator, processRequest r // makeMessageValidator constructs an instance of messageValidator // using the supplied abstractions. -func makeMessageValidator(validateRequest requestValidator, validatePrepare prepareValidator, validateCommit commitValidator) messageValidator { +func makeMessageValidator(validateRequest requestValidator, validatePrepare prepareValidator, validateCommit commitValidator, validateReqViewChange reqViewChangeValidator) messageValidator { return func(msg messages.Message) error { switch msg := msg.(type) { case messages.Request: @@ -415,7 +421,7 @@ func makeMessageValidator(validateRequest requestValidator, validatePrepare prep case messages.Commit: return validateCommit(msg) case messages.ReqViewChange: - return fmt.Errorf("not implemented") + return validateReqViewChange(msg) default: panic("Unknown message type") } @@ -437,13 +443,15 @@ func makeMessageProcessor(processRequest requestProcessor, processPeerMessage pe } } -func makePeerMessageProcessor(processEmbedded embeddedMessageProcessor, processUIMessage uiMessageProcessor) peerMessageProcessor { +func makePeerMessageProcessor(processEmbedded embeddedMessageProcessor, processUIMessage uiMessageProcessor, processReqViewChange reqViewChangeProcessor) peerMessageProcessor { return func(msg messages.PeerMessage) (new bool, err error) { processEmbedded(msg) switch msg := msg.(type) { case messages.CertifiedMessage: return processUIMessage(msg) + case messages.ReqViewChange: + return processReqViewChange(msg) default: panic("Unknown message type") } diff --git a/core/message-handling_test.go b/core/message-handling_test.go index 7245480..687ae28 100644 --- a/core/message-handling_test.go +++ b/core/message-handling_test.go @@ -198,11 +198,16 @@ func TestMakeMessageValidator(t *testing.T) { args := mock.MethodCalled("commitValidator", msg) return args.Error(0) } - validateMessage := makeMessageValidator(validateRequest, validatePrepare, validateCommit) + validateReqViewChange := func(msg messages.ReqViewChange) error { + args := mock.MethodCalled("reqViewChangeValidator", msg) + return args.Error(0) + } + validateMessage := makeMessageValidator(validateRequest, validatePrepare, validateCommit, validateReqViewChange) request := messageImpl.NewRequest(0, rand.Uint64(), nil) prepare := messageImpl.NewPrepare(0, 0, request) commit := messageImpl.NewCommit(0, prepare) + rvc := messageImpl.NewReqViewChange(0, rand.Uint64()) t.Run("UnknownMessageType", func(t *testing.T) { msg := mock_messages.NewMockMessage(ctrl) @@ -235,6 +240,15 @@ func TestMakeMessageValidator(t *testing.T) { err = validateMessage(commit) assert.NoError(t, err) }) + t.Run("ReqViewChange", func(t *testing.T) { + mock.On("reqViewChangeValidator", rvc).Return(fmt.Errorf("Error")).Once() + err := validateMessage(rvc) + assert.Error(t, err, "Invalid ReqViewChange") + + mock.On("reqViewChangeValidator", rvc).Return(nil).Once() + err = validateMessage(rvc) + assert.NoError(t, err) + }) } func TestMakeMessageProcessor(t *testing.T) { @@ -311,7 +325,11 @@ func TestMakePeerMessageProcessor(t *testing.T) { args := mock.MethodCalled("uiMessageProcessor", msg) return args.Bool(0), args.Error(1) } - process := makePeerMessageProcessor(processEmbedded, processUIMessage) + processReqViewChange := func(msg messages.ReqViewChange) (new bool, err error) { + args := mock.MethodCalled("reqViewChangeProcessor", msg) + return args.Bool(0), args.Error(1) + } + process := makePeerMessageProcessor(processEmbedded, processUIMessage, processReqViewChange) t.Run("UnknownMessageType", func(t *testing.T) { msg := mock_messages.NewMockPeerMessage(ctrl) @@ -350,6 +368,26 @@ func TestMakePeerMessageProcessor(t *testing.T) { assert.False(t, new) }) + t.Run("ReqViewChange", func(t *testing.T) { + msg := messageImpl.NewReqViewChange(rand.Uint32(), rand.Uint64()) + + mock.On("embeddedMessageProcessor", msg).Once() + mock.On("reqViewChangeProcessor", msg).Return(false, fmt.Errorf("Error")).Once() + _, err := process(msg) + assert.Error(t, err, "Failed to finish processing certified message") + + mock.On("embeddedMessageProcessor", msg).Once() + mock.On("reqViewChangeProcessor", msg).Return(true, nil).Once() + new, err := process(msg) + assert.NoError(t, err) + assert.True(t, new) + + mock.On("embeddedMessageProcessor", msg).Once() + mock.On("reqViewChangeProcessor", msg).Return(false, nil).Once() + new, err = process(msg) + assert.NoError(t, err) + assert.False(t, new) + }) } func TestMakeEmbeddedMessageProcessor(t *testing.T) { diff --git a/core/req-view-change.go b/core/req-view-change.go new file mode 100644 index 0000000..611f332 --- /dev/null +++ b/core/req-view-change.go @@ -0,0 +1,138 @@ +// Copyright (c) 2021 NEC Laboratories Europe GmbH. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package minbft + +import ( + "fmt" + "sync" + + "github.com/hyperledger-labs/minbft/core/internal/messagelog" + "github.com/hyperledger-labs/minbft/core/internal/viewstate" + "github.com/hyperledger-labs/minbft/messages" +) + +// reqViewChangeValidator validates a ReqViewChangeMessage. +// +// It authenticates and checks the supplied message for internal +// consistency. It does not use replica's current state and has no +// side-effect. It is safe to invoke concurrently. +type reqViewChangeValidator func(rvc messages.ReqViewChange) error + +// reqViewChangeProcessor processes a valid ReqViewChange message. +// +// It continues processing of the supplied message. The return value +// new indicates if the message had any effect. It is safe to invoke +// concurrently. +type reqViewChangeProcessor func(rvc messages.ReqViewChange) (new bool, err error) + +// reqViewChangeCollector collects view change requests. +// +// The supplied ReqViewChange message is assumed to be valid. Once the +// threshold of matching ReqViewChange messages from distinct replicas +// referring to the next view has been reached, it returns a +// view-change certificate comprised of those messages. The return +// value new indicates if the message had any effect. +type reqViewChangeCollector func(rvc messages.ReqViewChange) (new bool, _ messages.ViewChangeCert) + +// viewChangeStarter attempts to start view change. +// +// It proceeds to trigger view change with the supplied expected new +// view number justified by the supplied view-change certificate +// unless the replica cannot transition to that view anymore. +type viewChangeStarter func(newView uint64, vcCert messages.ViewChangeCert) (ok bool, err error) + +func makeReqViewChangeValidator(verifySignature messageSignatureVerifier) reqViewChangeValidator { + return func(rvc messages.ReqViewChange) error { + if rvc.NewView() < 1 { + return fmt.Errorf("Invalid new view number") + } + + if err := verifySignature(rvc); err != nil { + return fmt.Errorf("Signature is not valid: %s", err) + } + + return nil + } +} + +func makeReqViewChangeProcessor(collect reqViewChangeCollector, startViewChange viewChangeStarter) reqViewChangeProcessor { + var lock sync.Mutex + + return func(rvc messages.ReqViewChange) (new bool, err error) { + lock.Lock() + defer lock.Unlock() + + new, vcCert := collect(rvc) + if vcCert == nil { + return new, nil + } + + return startViewChange(rvc.NewView(), vcCert) + } +} + +func makeReqViewChangeCollector(f uint32) reqViewChangeCollector { + var ( + view uint64 + collected = make(messages.ViewChangeCert, 0, f+1) + replicas = make(map[uint32]bool, f+1) + ) + + return func(rvc messages.ReqViewChange) (new bool, vcCert messages.ViewChangeCert) { + replicaID := rvc.ReplicaID() + + if rvc.NewView() != view+1 || replicas[replicaID] { + return false, nil + } + + collected = append(collected, rvc) + replicas[replicaID] = true + + if uint32(len(collected)) <= f { + return true, nil + } + + vcCert = collected + collected = make(messages.ViewChangeCert, 0, f+1) + replicas = make(map[uint32]bool, f+1) + view++ + + return true, vcCert + } +} + +func makeViewChangeStarter(id uint32, viewState viewstate.State, log messagelog.MessageLog, handleGeneratedMessage generatedMessageHandler) viewChangeStarter { + return func(newView uint64, vcCert messages.ViewChangeCert) (ok bool, err error) { + ok, release := viewState.AdvanceExpectedView(newView) + if !ok { + return false, nil + } + defer release() + + var msgs messages.MessageLog + for _, m := range log.Messages() { + if m, ok := m.(messages.CertifiedMessage); ok { + msgs = append(msgs, m) + } + } + log.Reset(nil) + + // TODO: start view-change timer + + handleGeneratedMessage(messageImpl.NewViewChange(id, newView, msgs, vcCert)) + + return true, nil + } +} diff --git a/core/req-view-change_test.go b/core/req-view-change_test.go new file mode 100644 index 0000000..6b2fb62 --- /dev/null +++ b/core/req-view-change_test.go @@ -0,0 +1,195 @@ +// Copyright (c) 2021 NEC Laboratories Europe GmbH. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package minbft + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/golang/mock/gomock" + mock_messagelog "github.com/hyperledger-labs/minbft/core/internal/messagelog/mocks" + mock_viewstate "github.com/hyperledger-labs/minbft/core/internal/viewstate/mocks" + "github.com/hyperledger-labs/minbft/messages" + "github.com/hyperledger-labs/minbft/usig" + "github.com/stretchr/testify/assert" + testifymock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +func TestMakeReqViewChangeValidator(t *testing.T) { + mock := new(testifymock.Mock) + defer mock.AssertExpectations(t) + + verify := func(msg messages.SignedMessage) error { + args := mock.MethodCalled("messageSignatureVerifier", msg) + return args.Error(0) + } + validate := makeReqViewChangeValidator(verify) + + rvc := messageImpl.NewReqViewChange(rand.Uint32(), 0) + err := validate(rvc) + assert.Error(t, err, "Invalid new view number") + + rvc = messageImpl.NewReqViewChange(rand.Uint32(), 1+uint64(rand.Int63())) + + mock.On("messageSignatureVerifier", rvc).Return(fmt.Errorf("invalid signature")).Once() + err = validate(rvc) + assert.Error(t, err, "Invalid signature") + + mock.On("messageSignatureVerifier", rvc).Return(nil).Once() + err = validate(rvc) + assert.NoError(t, err) +} + +func TestMakeReqViewChangeProcessor(t *testing.T) { + mock := new(testifymock.Mock) + defer mock.AssertExpectations(t) + + collect := func(rvc messages.ReqViewChange) (new bool, _ messages.ViewChangeCert) { + args := mock.MethodCalled("reqViewChangeCollector", rvc) + return args.Bool(0), args.Get(1).(messages.ViewChangeCert) + } + start := func(newView uint64, vcCert messages.ViewChangeCert) (ok bool, _ error) { + args := mock.MethodCalled("viewChangeStarter", newView, vcCert) + return args.Bool(0), args.Error(1) + } + process := makeReqViewChangeProcessor(collect, start) + + newView := rand.Uint64() + rvc := messageImpl.NewReqViewChange(1, newView) + cert := messages.ViewChangeCert{rvc, messageImpl.NewReqViewChange(2, newView)} + + mock.On("reqViewChangeCollector", rvc).Return(false, messages.ViewChangeCert(nil)).Once() + new, err := process(rvc) + assert.False(t, new) + assert.NoError(t, err) + + mock.On("reqViewChangeCollector", rvc).Return(true, messages.ViewChangeCert(nil)).Once() + new, err = process(rvc) + assert.True(t, new) + assert.NoError(t, err) + + mock.On("reqViewChangeCollector", rvc).Return(true, cert).Once() + mock.On("viewChangeStarter", newView, cert).Return(false, nil).Once() + new, err = process(rvc) + assert.False(t, new) + assert.NoError(t, err) + + mock.On("reqViewChangeCollector", rvc).Return(true, cert).Once() + mock.On("viewChangeStarter", newView, cert).Return(true, nil).Once() + new, err = process(rvc) + assert.True(t, new) + assert.NoError(t, err) +} + +func TestMakeReqViewChangeCollector(t *testing.T) { + const f = 1 + + var cases []struct { + ID uint32 + NewView uint64 + New bool + Cert []int + } + casesYAML := []byte(` +- {id: 0, newview: 1, new: y, cert: [] } #0 +- {id: 0, newview: 1, new: n, cert: [] } #1 +- {id: 1, newview: 1, new: y, cert: [0, 2]} #2 +- {id: 2, newview: 1, new: n, cert: [] } #3 +- {id: 1, newview: 2, new: y, cert: [] } #4 +- {id: 0, newview: 3, new: n, cert: [] } #5 +- {id: 2, newview: 2, new: y, cert: [4, 6]} #6 +`) + if err := yaml.UnmarshalStrict(casesYAML, &cases); err != nil { + t.Fatal(err) + } + + var msgs []messages.ReqViewChange + var certs []messages.ViewChangeCert + for _, c := range cases { + rvc := messageImpl.NewReqViewChange(c.ID, c.NewView) + msgs = append(msgs, rvc) + + var cert messages.ViewChangeCert + for _, i := range c.Cert { + cert = append(cert, msgs[i]) + } + certs = append(certs, cert) + } + + collect := makeReqViewChangeCollector(f) + for i, c := range cases { + desc := fmt.Sprintf("Case #%d", i) + new, cert := collect(msgs[i]) + require.Equal(t, c.New, new, desc) + require.Equal(t, certs[i], cert, desc) + } +} + +func TestMakeViewChangeStarter(t *testing.T) { + mock := new(testifymock.Mock) + defer mock.AssertExpectations(t) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + id := rand.Uint32() + viewState := mock_viewstate.NewMockState(ctrl) + log := mock_messagelog.NewMockMessageLog(ctrl) + handleGenerated := func(msg messages.ReplicaMessage) { + mock.MethodCalled("generatedMessageHandler", msg) + } + start := makeViewChangeStarter(id, viewState, log, handleGenerated) + + newView := rand.Uint64() + cert := messages.ViewChangeCert{ + messageImpl.NewReqViewChange(1, newView), + messageImpl.NewReqViewChange(2, newView), + } + msgs := []messages.Message{} + viewLog := messages.MessageLog{} + for cv := uint64(1); cv <= 2; cv++ { + prep := messageImpl.NewPrepare(0, 0, messageImpl.NewRequest(0, cv, randBytes())) + prep.SetUI(&usig.UI{Counter: cv, Cert: randBytes()}) + comm := messageImpl.NewCommit(1, prep) + comm.SetUI(&usig.UI{Counter: cv, Cert: randBytes()}) + msgs = append(msgs, comm) + viewLog = append(viewLog, comm) + } + req := messageImpl.NewRequest(0, 1, randBytes()) + msgs = append(msgs, nil) + copy(msgs[2:], msgs[1:]) + msgs[1] = req + + viewState.EXPECT().AdvanceExpectedView(newView).Return(false, nil) + ok, err := start(newView, cert) + assert.False(t, ok) + assert.NoError(t, err) + + viewState.EXPECT().AdvanceExpectedView(newView).Return(true, func() { + mock.MethodCalled("viewReleaser") + }).AnyTimes() + + vc := messageImpl.NewViewChange(id, newView, viewLog, cert) + log.EXPECT().Messages().Return(msgs) + log.EXPECT().Reset(nil) + mock.On("generatedMessageHandler", vc).Once() + mock.On("viewReleaser").Once() + ok, err = start(newView, cert) + assert.True(t, ok) + assert.NoError(t, err) +}