From 230a013c1dc7e364256c583cf700a51fe256b5c8 Mon Sep 17 00:00:00 2001 From: Denis Rechkunov Date: Tue, 9 Aug 2022 09:52:30 +0200 Subject: [PATCH] Implement the shipper gRPC server (#76) All the endpoints are actually functioning now. Following use-cases are implemented: * Publish a full batch of events * Publish a batch partially based on the queue capacity * Return an error if the UUID of server does not match the client request * Request or subscribe to persisted index change --- NOTICE.txt | 64 +++--- go.mod | 4 +- go.sum | 4 +- queue/queue.go | 4 +- server/controller_client_test.go | 112 ++++++++++ server/run.go | 9 +- server/server.go | 193 ++++++++++++++--- server/server_test.go | 355 +++++++++++++++++++++++-------- 8 files changed, 585 insertions(+), 160 deletions(-) create mode 100644 server/controller_client_test.go diff --git a/NOTICE.txt b/NOTICE.txt index 2ccda56..a84b4fd 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -478,11 +478,11 @@ Contents of probable licence file $GOMODCACHE/github.com/elastic/elastic-agent-l -------------------------------------------------------------------------------- Dependency : github.com/elastic/elastic-agent-shipper-client -Version: v0.3.0 +Version: v0.4.0 Licence type (autodetected): Elastic -------------------------------------------------------------------------------- -Contents of probable licence file $GOMODCACHE/github.com/elastic/elastic-agent-shipper-client@v0.3.0/LICENSE.txt: +Contents of probable licence file $GOMODCACHE/github.com/elastic/elastic-agent-shipper-client@v0.4.0/LICENSE.txt: Elastic License 2.0 @@ -790,6 +790,36 @@ Contents of probable licence file $GOMODCACHE/github.com/elastic/go-ucfg@v0.8.6/ limitations under the License. +-------------------------------------------------------------------------------- +Dependency : github.com/gofrs/uuid +Version: v4.2.0+incompatible +Licence type (autodetected): MIT +-------------------------------------------------------------------------------- + +Contents of probable licence file $GOMODCACHE/github.com/gofrs/uuid@v4.2.0+incompatible/LICENSE: + +Copyright (C) 2013-2018 by Maxim Bublis + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + -------------------------------------------------------------------------------- Dependency : github.com/magefile/mage Version: v1.13.0 @@ -25123,36 +25153,6 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --------------------------------------------------------------------------------- -Dependency : github.com/gofrs/uuid -Version: v4.2.0+incompatible -Licence type (autodetected): MIT --------------------------------------------------------------------------------- - -Contents of probable licence file $GOMODCACHE/github.com/gofrs/uuid@v4.2.0+incompatible/LICENSE: - -Copyright (C) 2013-2018 by Maxim Bublis - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - -------------------------------------------------------------------------------- Dependency : github.com/gogo/protobuf Version: v1.3.2 diff --git a/go.mod b/go.mod index b386ea9..849bde9 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,9 @@ require ( require ( github.com/elastic/beats/v7 v7.0.0-alpha2.0.20220722175030-7cb39607b349 github.com/elastic/elastic-agent-client/v7 v7.0.0-20220607160924-1a71765a8bbe - github.com/elastic/elastic-agent-shipper-client v0.3.0 + github.com/elastic/elastic-agent-shipper-client v0.4.0 github.com/elastic/go-ucfg v0.8.6 + github.com/gofrs/uuid v4.2.0+incompatible github.com/magefile/mage v1.13.0 github.com/stretchr/testify v1.7.1 go.elastic.co/go-licence-detector v0.5.0 @@ -32,7 +33,6 @@ require ( github.com/elastic/go-windows v1.0.1 // indirect github.com/fatih/color v1.13.0 // indirect github.com/gobuffalo/here v0.6.0 // indirect - github.com/gofrs/uuid v4.2.0+incompatible // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/licenseclassifier v0.0.0-20200402202327-879cb1424de0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index b3e4e5f..2460169 100644 --- a/go.sum +++ b/go.sum @@ -501,8 +501,8 @@ github.com/elastic/elastic-agent-libs v0.2.7/go.mod h1:chO3rtcLyGlKi9S0iGVZhYCzD github.com/elastic/elastic-agent-libs v0.2.9 h1:7jOCqNqEWG0kJb3fa8/SC6beSiys1TmAylH9+hWTnrM= github.com/elastic/elastic-agent-libs v0.2.9/go.mod h1:chO3rtcLyGlKi9S0iGVZhYCzDfdDsAQYBc+ui588AFE= github.com/elastic/elastic-agent-shipper-client v0.2.0/go.mod h1:OyI2W+Mv3JxlkEF3OeT7K0dbuxvwew8ke2Cf4HpLa9Q= -github.com/elastic/elastic-agent-shipper-client v0.3.0 h1:Ec/r08WLB6tfR95lW9JUTCUYD+HcqQdK7MR7n+0ax4s= -github.com/elastic/elastic-agent-shipper-client v0.3.0/go.mod h1:OyI2W+Mv3JxlkEF3OeT7K0dbuxvwew8ke2Cf4HpLa9Q= +github.com/elastic/elastic-agent-shipper-client v0.4.0 h1:nsTJF9oo4RHLl+zxFUZqNHaE86C6Ba5aImfegcEf6Sk= +github.com/elastic/elastic-agent-shipper-client v0.4.0/go.mod h1:OyI2W+Mv3JxlkEF3OeT7K0dbuxvwew8ke2Cf4HpLa9Q= github.com/elastic/elastic-agent-system-metrics v0.4.3/go.mod h1:tF/f9Off38nfzTZHIVQ++FkXrDm9keFhFpJ+3pQ00iI= github.com/elastic/elastic-transport-go/v8 v8.1.0/go.mod h1:87Tcz8IVNe6rVSLdBux1o/PEItLtyabHU3naC7IoqKI= github.com/elastic/go-concert v0.2.0/go.mod h1:HWjpO3IAEJUxOeaJOWXWEp7imKd27foxz9V5vegC/38= diff --git a/queue/queue.go b/queue/queue.go index d4e54e0..e997321 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -77,8 +77,8 @@ func (queue *Queue) Get(eventCount int) (beatsqueue.Batch, error) { return queue.eventQueue.Get(eventCount) } -func (queue *Queue) Close() { - queue.eventQueue.Close() +func (queue *Queue) Close() error { + return queue.eventQueue.Close() } func (queue *Queue) AcceptedIndex() EntryID { diff --git a/server/controller_client_test.go b/server/controller_client_test.go new file mode 100644 index 0000000..93c613e --- /dev/null +++ b/server/controller_client_test.go @@ -0,0 +1,112 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package server + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/elastic/elastic-agent-client/v7/pkg/client" + "github.com/elastic/elastic-agent-client/v7/pkg/client/mock" + "github.com/elastic/elastic-agent-client/v7/pkg/proto" +) + +func TestAgentControl(t *testing.T) { + unitOneID := mock.NewID() + + token := mock.NewID() + var gotConfig, gotHealthy, gotStopped bool + + var mut sync.Mutex + + t.Logf("Creating mock server") + srv := mock.StubServerV2{ + CheckinV2Impl: func(observed *proto.CheckinObserved) *proto.CheckinExpected { + mut.Lock() + defer mut.Unlock() + if observed.Token == token { + if len(observed.Units) > 0 { + t.Logf("Current unit state is: %v", observed.Units[0].State) + } + + // initial checkin + if len(observed.Units) == 0 || observed.Units[0].State == proto.State_STARTING { + gotConfig = true + return &proto.CheckinExpected{ + Units: []*proto.UnitExpected{ + { + Id: unitOneID, + Type: proto.UnitType_OUTPUT, + ConfigStateIdx: 1, + Config: `{"logging": {"level": "debug"}}`, // hack to make my life easier + State: proto.State_HEALTHY, + }, + }, + } + } else if observed.Units[0].State == proto.State_HEALTHY { + gotHealthy = true + //shutdown + return &proto.CheckinExpected{ + Units: []*proto.UnitExpected{ + { + Id: unitOneID, + Type: proto.UnitType_OUTPUT, + ConfigStateIdx: 1, + Config: "{}", + State: proto.State_STOPPED, + }, + }, + } + } else if observed.Units[0].State == proto.State_STOPPED { + gotStopped = true + // remove the unit? I think? + return &proto.CheckinExpected{ + Units: nil, + } + } + + } + + //gotInvalid = true + return nil + }, + ActionImpl: func(response *proto.ActionResponse) error { + + return nil + }, + ActionsChan: make(chan *mock.PerformAction, 100), + } // end of srv declaration + + require.NoError(t, srv.Start()) + defer srv.Stop() + + t.Logf("creating client") + // connect with client + validClient := client.NewV2(fmt.Sprintf(":%d", srv.Port), token, client.VersionInfo{ + Name: "program", + Version: "v1.0.0", + Meta: map[string]string{ + "key": "value", + }, + }, grpc.WithTransportCredentials(insecure.NewCredentials())) + + t.Logf("starting shipper controller") + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + err := runController(ctx, validClient) + assert.NoError(t, err) + + assert.True(t, gotConfig, "config state") + assert.True(t, gotHealthy, "healthy state") + assert.True(t, gotStopped, "stopped state") +} diff --git a/server/run.go b/server/run.go index 83c91e2..28793c2 100644 --- a/server/run.go +++ b/server/run.go @@ -104,11 +104,11 @@ func (c *clientHandler) Run(cfg config.ShipperConfig, unit *client.Unit) error { opts = []grpc.ServerOption{grpc.Creds(creds)} } grpcServer := grpc.NewServer(opts...) - r := shipperServer{ - logger: log, - queue: queue, + shipperServer, err := NewShipperServer(queue) + if err != nil { + return fmt.Errorf("failed to initialise the server: %w", err) } - pb.RegisterProducerServer(grpcServer, r) + pb.RegisterProducerServer(grpcServer, shipperServer) shutdownFunc := func() { grpcServer.GracefulStop() @@ -118,6 +118,7 @@ func (c *clientHandler) Run(cfg config.ShipperConfig, unit *client.Unit) error { // We call Wait to give it a chance to finish with events // it has already read. out.Wait() + shipperServer.Close() } handleShutdown(shutdownFunc, c.shutdownInit) log.Debugf("gRPC server is listening on port %d", cfg.Port) diff --git a/server/server.go b/server/server.go index 3d23cb9..7f508bd 100644 --- a/server/server.go +++ b/server/server.go @@ -6,50 +6,187 @@ package server import ( "context" + "errors" + "fmt" + "io" + "sync" + "time" "github.com/elastic/elastic-agent-libs/logp" + pb "github.com/elastic/elastic-agent-shipper-client/pkg/proto" "github.com/elastic/elastic-agent-shipper-client/pkg/proto/messages" "github.com/elastic/elastic-agent-shipper/queue" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/gofrs/uuid" ) +// Publisher contains all operations required for the shipper server to publish incoming events. +type Publisher interface { + io.Closer + + // AcceptedIndex returns the current sequential index of the accepted events + AcceptedIndex() queue.EntryID + // PersistedIndex returns the current sequential index of the persisted events + PersistedIndex() queue.EntryID + // Publish publishes the given event and returns the current accepted index (after this event) + Publish(*messages.Event) (queue.EntryID, error) +} + +// ShipperServer contains all the gRPC operations for the shipper endpoints. +type ShipperServer interface { + pb.ProducerServer + io.Closer +} + type shipperServer struct { - logger *logp.Logger + logger *logp.Logger + publisher Publisher + + uuid string - queue *queue.Queue + close *sync.Once + ctx context.Context + stop func() pb.UnimplementedProducerServer } -// PublishEvents is the server implementation of the gRPC PublishEvents call -func (serv shipperServer) PublishEvents(_ context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) { - reply := &messages.PublishReply{} - for _, evt := range req.Events { - serv.logger.Infof("Got event: %s", evt.String()) - entryID, err := serv.queue.Publish(evt) - if err != nil { - // If we couldn't accept any events, return the error directly. Otherwise, - // just return success on however many events we were able to handle. - if reply.AcceptedCount == 0 { - return nil, err - } - break +// NewShipperServer creates a new server instance for handling gRPC endpoints. +func NewShipperServer(publisher Publisher) (ShipperServer, error) { + if publisher == nil { + return nil, errors.New("publisher cannot be nil") + } + + id, err := uuid.NewV4() + if err != nil { + return nil, err + } + + s := shipperServer{ + uuid: id.String(), + logger: logp.NewLogger("shipper-server"), + publisher: publisher, + close: &sync.Once{}, + } + + s.ctx, s.stop = context.WithCancel(context.Background()) + + return &s, nil +} + +// GetAcceptedIndex returns the accepted index +func (serv *shipperServer) GetAcceptedIndex() uint64 { + return uint64(serv.publisher.AcceptedIndex()) +} + +// GetPersistedIndex returns the persisted index +func (serv *shipperServer) GetPersistedIndex() uint64 { + return uint64(serv.publisher.PersistedIndex()) +} + +// PublishEvents is the server implementation of the gRPC PublishEvents call. +func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) { + resp := &messages.PublishReply{ + Uuid: serv.uuid, + } + + // the value in the request is optional + if req.Uuid != "" && req.Uuid != serv.uuid { + resp.AcceptedIndex = serv.GetAcceptedIndex() + resp.PersistedIndex = serv.GetPersistedIndex() + serv.logger.Debugf("shipper UUID does not match, all events rejected. Expected = %s, actual = %s", serv.uuid, req.Uuid) + + return resp, status.Error(codes.FailedPrecondition, fmt.Sprintf("UUID does not match. Expected = %s, actual = %s", serv.uuid, req.Uuid)) + } + + for _, e := range req.Events { + _, err := serv.publisher.Publish(e) + if err == nil { + resp.AcceptedCount++ + continue } - reply.AcceptedCount = reply.AcceptedCount + 1 - reply.AcceptedIndex = uint64(entryID) + + if errors.Is(err, queue.ErrQueueIsFull) { + serv.logger.Debugf("queue is full, not all events accepted. Events = %d, accepted = %d", len(req.Events), resp.AcceptedCount) + } else { + err = fmt.Errorf("failed to enqueue an event. Events = %d, accepted = %d: %w", len(req.Events), resp.AcceptedCount, err) + serv.logger.Error(err) + } + + break } - return reply, nil + + resp.AcceptedIndex = serv.GetAcceptedIndex() + resp.PersistedIndex = serv.GetPersistedIndex() + + serv.logger. + Debugf("finished publishing a batch. Events = %d, accepted = %d, accepted index = %d, persisted index = %d", + len(req.Events), + resp.AcceptedCount, + resp.AcceptedIndex, + resp.PersistedIndex, + ) + + return resp, nil } -// StreamAcknowledgements is the server implementation of the gRPC StreamAcknowledgements call -// func (serv shipperServer) StreamAcknowledgements(streamReq *messages.StreamAcksRequest, prd pb.Producer_StreamAcknowledgementsServer) error { +// PublishEvents is the server implementation of the gRPC PersistedIndex call. +func (serv *shipperServer) PersistedIndex(req *messages.PersistedIndexRequest, producer pb.Producer_PersistedIndexServer) error { + serv.logger.Debug("new subscriber for persisted index change") + defer serv.logger.Debug("unsubscribed from persisted index change") + + persistedIndex := serv.GetPersistedIndex() + err := producer.Send(&messages.PersistedIndexReply{ + Uuid: serv.uuid, + PersistedIndex: persistedIndex, + }) + if err != nil { + return err + } + + pollingIntervalDur := req.PollingInterval.AsDuration() -// // we have no outputs now, so just send a single dummy event -// evt := messages.StreamAcksReply{Acks: []*messages.Acknowledgement{{Timestamp: pbts.Now(), EventId: streamReq.Source.GetInputId()}}} -// err := prd.Send(&evt) + if pollingIntervalDur == 0 { + return nil + } + + ticker := time.NewTicker(pollingIntervalDur) + defer ticker.Stop() + + for { + select { + case <-producer.Context().Done(): + return fmt.Errorf("producer context: %w", producer.Context().Err()) + + case <-serv.ctx.Done(): + return fmt.Errorf("server is stopped: %w", serv.ctx.Err()) + + case <-ticker.C: + newPersistedIndex := serv.GetPersistedIndex() + if newPersistedIndex == persistedIndex { + continue + } + persistedIndex = newPersistedIndex + err := producer.Send(&messages.PersistedIndexReply{ + Uuid: serv.uuid, + PersistedIndex: persistedIndex, + }) + if err != nil { + return fmt.Errorf("failed to send the update: %w", err) + } + } + } +} + +// Close implements the Closer interface +func (serv *shipperServer) Close() error { + serv.close.Do(func() { + serv.stop() + }) -// if err != nil { -// return fmt.Errorf("error in StreamAcknowledgements: %w", err) -// } -// return nil -// } + return nil +} diff --git a/server/server_test.go b/server/server_test.go index 93c613e..a21458b 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,107 +6,282 @@ package server import ( "context" - "fmt" - "sync" + "net" + "strings" "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/elastic/elastic-agent-shipper-client/pkg/helpers" + pb "github.com/elastic/elastic-agent-shipper-client/pkg/proto" + "github.com/elastic/elastic-agent-shipper-client/pkg/proto/messages" + "github.com/elastic/elastic-agent-shipper/queue" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - - "github.com/elastic/elastic-agent-client/v7/pkg/client" - "github.com/elastic/elastic-agent-client/v7/pkg/client/mock" - "github.com/elastic/elastic-agent-client/v7/pkg/proto" + "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" ) -func TestAgentControl(t *testing.T) { - unitOneID := mock.NewID() - - token := mock.NewID() - var gotConfig, gotHealthy, gotStopped bool - - var mut sync.Mutex - - t.Logf("Creating mock server") - srv := mock.StubServerV2{ - CheckinV2Impl: func(observed *proto.CheckinObserved) *proto.CheckinExpected { - mut.Lock() - defer mut.Unlock() - if observed.Token == token { - if len(observed.Units) > 0 { - t.Logf("Current unit state is: %v", observed.Units[0].State) - } - - // initial checkin - if len(observed.Units) == 0 || observed.Units[0].State == proto.State_STARTING { - gotConfig = true - return &proto.CheckinExpected{ - Units: []*proto.UnitExpected{ - { - Id: unitOneID, - Type: proto.UnitType_OUTPUT, - ConfigStateIdx: 1, - Config: `{"logging": {"level": "debug"}}`, // hack to make my life easier - State: proto.State_HEALTHY, - }, - }, - } - } else if observed.Units[0].State == proto.State_HEALTHY { - gotHealthy = true - //shutdown - return &proto.CheckinExpected{ - Units: []*proto.UnitExpected{ - { - Id: unitOneID, - Type: proto.UnitType_OUTPUT, - ConfigStateIdx: 1, - Config: "{}", - State: proto.State_STOPPED, - }, - }, - } - } else if observed.Units[0].State == proto.State_STOPPED { - gotStopped = true - // remove the unit? I think? - return &proto.CheckinExpected{ - Units: nil, - } - } - - } - - //gotInvalid = true - return nil - }, - ActionImpl: func(response *proto.ActionResponse) error { +const bufSize = 1024 * 1024 // 1MB + +func TestPublish(t *testing.T) { + ctx := context.Background() - return nil + sampleValues, err := helpers.NewStruct(map[string]interface{}{ + "string": "value", + "number": 42, + }) + + require.NoError(t, err) + + e := &messages.Event{ + Timestamp: timestamppb.Now(), + Source: &messages.Source{ + InputId: "input", + StreamId: "stream", }, - ActionsChan: make(chan *mock.PerformAction, 100), - } // end of srv declaration - - require.NoError(t, srv.Start()) - defer srv.Stop() - - t.Logf("creating client") - // connect with client - validClient := client.NewV2(fmt.Sprintf(":%d", srv.Port), token, client.VersionInfo{ - Name: "program", - Version: "v1.0.0", - Meta: map[string]string{ - "key": "value", + DataStream: &messages.DataStream{ + Type: "log", + Dataset: "default", + Namespace: "default", }, - }, grpc.WithTransportCredentials(insecure.NewCredentials())) + Metadata: sampleValues, + Fields: sampleValues, + } + + publisher := &publisherMock{ + persistedIndex: 42, + } + shipper, err := NewShipperServer(publisher) + defer func() { _ = shipper.Close() }() + require.NoError(t, err) + client, stop := startServer(t, ctx, shipper) + defer stop() + + // get the current UUID + pirCtx, cancel := context.WithCancel(ctx) + consumer, err := client.PersistedIndex(pirCtx, &messages.PersistedIndexRequest{}) + require.NoError(t, err) + pir, err := consumer.Recv() + require.NoError(t, err) + cancel() // close the stream + + t.Run("should successfully publish a batch", func(t *testing.T) { + publisher.q = make([]*messages.Event, 0, 3) + events := []*messages.Event{e, e, e} + reply, err := client.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: pir.Uuid, + Events: events, + }) + require.NoError(t, err) + require.Equal(t, uint32(len(events)), reply.AcceptedCount) + require.Equal(t, uint64(len(events)), reply.AcceptedIndex) + require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + }) + + t.Run("should grow accepted index", func(t *testing.T) { + publisher.q = make([]*messages.Event, 0, 3) + events := []*messages.Event{e} + reply, err := client.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: pir.Uuid, + Events: events, + }) + require.NoError(t, err) + require.Equal(t, uint32(len(events)), reply.AcceptedCount) + require.Equal(t, uint64(1), reply.AcceptedIndex) + require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + reply, err = client.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: pir.Uuid, + Events: events, + }) + require.NoError(t, err) + require.Equal(t, uint32(len(events)), reply.AcceptedCount) + require.Equal(t, uint64(2), reply.AcceptedIndex) + require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + reply, err = client.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: pir.Uuid, + Events: events, + }) + require.NoError(t, err) + require.Equal(t, uint32(len(events)), reply.AcceptedCount) + require.Equal(t, uint64(3), reply.AcceptedIndex) + require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + }) + + t.Run("should return different count when queue is full", func(t *testing.T) { + publisher.q = make([]*messages.Event, 0, 1) + events := []*messages.Event{e, e, e} // 3 should not fit into the queue size 1 + reply, err := client.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: pir.Uuid, + Events: events, + }) + require.NoError(t, err) + require.Equal(t, uint32(1), reply.AcceptedCount) + require.Equal(t, uint64(1), reply.AcceptedIndex) + require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + }) + + t.Run("should return an error when uuid does not match", func(t *testing.T) { + publisher.q = make([]*messages.Event, 0, 3) + events := []*messages.Event{e, e, e} + reply, err := client.PublishEvents(ctx, &messages.PublishRequest{ + Uuid: "wrong", + Events: events, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "UUID does not match") + require.Nil(t, reply) + }) +} + +func TestPersistedIndex(t *testing.T) { + ctx := context.Background() + + publisher := &publisherMock{persistedIndex: 42} + + t.Run("server should send updates to the clients", func(t *testing.T) { + shipper, err := NewShipperServer(publisher) + defer func() { _ = shipper.Close() }() + require.NoError(t, err) + client, stop := startServer(t, ctx, shipper) + defer stop() + + // first delivery can happen before the first index update + require.Eventually(t, func() bool { + cl := createConsumers(t, ctx, client, 5, 5*time.Millisecond) + defer cl.stop() + return cl.assertConsumed(t, 42) // initial value in the publisher + }, 100*time.Millisecond, time.Millisecond, "clients are supposed to get the update") + + cl := createConsumers(t, ctx, client, 50, 5*time.Millisecond) + publisher.persistedIndex = 64 + + cl.assertConsumed(t, 64) + + publisher.persistedIndex = 128 + + cl.assertConsumed(t, 128) + + cl.stop() + }) + + t.Run("server should properly shutdown", func(t *testing.T) { + shipper, err := NewShipperServer(publisher) + require.NoError(t, err) + client, stop := startServer(t, ctx, shipper) + defer stop() + + cl := createConsumers(t, ctx, client, 50, 5*time.Millisecond) + publisher.persistedIndex = 64 + shipper.Close() // stopping the server + require.Eventually(t, func() bool { + return cl.assertClosedServer(t) // initial value in the publisher + }, 100*time.Millisecond, time.Millisecond, "server was supposed to shutdown") + }) +} - t.Logf("starting shipper controller") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - err := runController(ctx, validClient) - assert.NoError(t, err) +func startServer(t *testing.T, ctx context.Context, shipperServer ShipperServer) (client pb.ProducerClient, stop func()) { + lis := bufconn.Listen(bufSize) + grpcServer := grpc.NewServer() + + pb.RegisterProducerServer(grpcServer, shipperServer) + go func() { + _ = grpcServer.Serve(lis) + }() + + bufDialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + conn, err := grpc.DialContext( + ctx, + "bufnet", + grpc.WithContextDialer(bufDialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + require.NoError(t, err) + } + + stop = func() { + shipperServer.Close() + conn.Close() + grpcServer.Stop() + } + + return pb.NewProducerClient(conn), stop +} + +func createConsumers(t *testing.T, ctx context.Context, client pb.ProducerClient, count int, pollingInterval time.Duration) consumerList { + ctx, cancel := context.WithCancel(ctx) + + cl := consumerList{ + stop: cancel, + consumers: make([]pb.Producer_PersistedIndexClient, 0, count), + } + for i := 0; i < 50; i++ { + consumer, err := client.PersistedIndex(ctx, &messages.PersistedIndexRequest{ + PollingInterval: durationpb.New(pollingInterval), + }) + require.NoError(t, err) + cl.consumers = append(cl.consumers, consumer) + } + + return cl +} + +type consumerList struct { + consumers []pb.Producer_PersistedIndexClient + stop func() +} + +func (l consumerList) assertConsumed(t *testing.T, value uint64) bool { + for _, c := range l.consumers { + pir, err := c.Recv() + require.NoError(t, err) + if pir.PersistedIndex != value { + return false + } + } + return true +} + +func (l consumerList) assertClosedServer(t *testing.T) bool { + for _, c := range l.consumers { + _, err := c.Recv() + if err == nil { + return false + } + + if !strings.Contains(err.Error(), "server is stopped: context canceled") { + return false + } + } + + return true +} + +type publisherMock struct { + Publisher + q []*messages.Event + persistedIndex queue.EntryID +} + +func (p *publisherMock) Publish(event *messages.Event) (queue.EntryID, error) { + if len(p.q) == cap(p.q) { + return queue.EntryID(0), queue.ErrQueueIsFull + } + + p.q = append(p.q, event) + return queue.EntryID(len(p.q)), nil +} + +func (p *publisherMock) AcceptedIndex() queue.EntryID { + return queue.EntryID(len(p.q)) +} - assert.True(t, gotConfig, "config state") - assert.True(t, gotHealthy, "healthy state") - assert.True(t, gotStopped, "stopped state") +func (p *publisherMock) PersistedIndex() queue.EntryID { + return p.persistedIndex }