diff --git a/Makefile b/Makefile index 43e82ba99..7cfe9fa2a 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,7 @@ KUSTOMIZE := $(GO) run sigs.k8s.io/kustomize/kustomize/v4@v4.5 SETUP_ENVTEST := $(GO) run sigs.k8s.io/controller-runtime/tools/setup-envtest@v0.0.0-20220304125252-9ee63fc65a97 GOLANGCI_LINT := $(GO) run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.52 YAMLFMT := $(GO) run github.com/google/yamlfmt/cmd/yamlfmt@v0.6 +MOQ := $(GO) run github.com/matryer/moq@v0.3 # Installed tools PROTOC_GEN_GO_GRPC := google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.2 @@ -94,6 +95,11 @@ e2e-test: ## Run e2e tests $(SETUP_ENVTEST) use source <($(SETUP_ENVTEST) use -p env) && $(GO) test -v ./internal/e2e/... -tags=e2e +mocks: + $(MOQ) -fmt goimpots -rm -out ./internal/proto/workflow/v2/mock.go ./internal/proto/workflow/v2 WorkflowServiceClient WorkflowService_StreamWorkflowsClient + $(MOQ) -fmt goimports -rm -out ./internal/agent/transport/mock.go ./internal/agent/transport WorkflowHandler + $(MOQ) -fmt goimports -rm -out ./internal/agent/mock.go ./internal/agent Transport ContainerRuntime + .PHONY: generate-proto generate-proto: buf.gen.yaml buf.lock $(shell git ls-files '**/*.proto') _protoc $(BUF) mod update @@ -242,4 +248,4 @@ yamllint: $(YAMLLINT_BIN) .PHONY: _protoc ## Install all required tools for use with this Makefile. _protoc: GOBIN=$${PWD}/bin $(GO) install $(PROTOC_GEN_GO) - GOBIN=$${PWD}/bin $(GO) install $(PROTOC_GEN_GO_GRPC) + GOBIN=$${PWD}/bin $(GO) install $(PROTOC_GEN_GO_GRPC) \ No newline at end of file diff --git a/go.mod b/go.mod index 23cdfecc9..ae9f1e0c2 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/go-logr/zerologr v1.2.3 github.com/google/go-cmp v0.5.9 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 + github.com/kr/pretty v0.3.1 github.com/onsi/ginkgo/v2 v2.9.4 github.com/onsi/gomega v1.27.6 github.com/opencontainers/image-spec v1.1.0-rc.3 @@ -77,6 +78,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/kr/text v0.2.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.12 // indirect diff --git a/go.sum b/go.sum index 9d128a19f..cf6c9974d 100644 --- a/go.sum +++ b/go.sum @@ -597,6 +597,7 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -690,6 +691,7 @@ github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index c22d1b099..4778adfde 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -3,7 +3,7 @@ package agent import ( "context" "errors" - "fmt" + "sync" "github.com/go-logr/logr" "github.com/tinkerbell/tink/internal/agent/event" @@ -30,6 +30,10 @@ type Agent struct { // sem ensure we handle a single workflow at a time. sem chan struct{} + + // executionContext tracks the currently executing workflow. + executionContext *executionContext + mtx sync.Mutex } // Start finalizes the Agent configuration and starts the configured Transport so it is ready @@ -45,7 +49,12 @@ func (agent *Agent) Start(ctx context.Context) error { } if agent.Runtime == nil { - return errors.New("agent.Runtime must be set before calling Start()") + //nolint:stylecheck // Runtime is a field of agent. + return errors.New("Runtime field must be set before calling Start()") + } + + if agent.Log.GetSink() == nil { + agent.Log = logr.Discard() } agent.Log = agent.Log.WithValues("agent_id", agent.ID) @@ -54,41 +63,79 @@ func (agent *Agent) Start(ctx context.Context) error { agent.sem = make(chan struct{}, 1) agent.sem <- struct{}{} - // Launch the transport ensuring we can recover any errors. - transportErr := make(chan error, 1) - go func() { - agent.Log.Info("Starting agent") - transportErr <- agent.Transport.Start(ctx, agent.ID, agent) - }() - - select { - case err := <-transportErr: - return fmt.Errorf("transport: %w", err) - case <-ctx.Done(): - return ctx.Err() - } + return agent.Transport.Start(ctx, agent.ID, agent) } // HandleWorkflow satisfies transport. -func (agent *Agent) HandleWorkflow(ctx context.Context, wflw workflow.Workflow, events event.Recorder) error { - // sem isn't protected by a synchronization data structure so this is technically invoking - // undefined behavior - consider this a best effort to ensuring Start() has been called. +func (agent *Agent) HandleWorkflow(ctx context.Context, wflw workflow.Workflow, events event.Recorder) { if agent.sem == nil { - return errors.New("agent must have Start() called before calling HandleWorkflow()") + agent.Log.Info("Agent must have Start() called before calling HandleWorkflow()") } select { case <-agent.sem: - // Replenish the semaphore on exit so we can pick up another workflow. - defer func() { agent.sem <- struct{}{} }() - return agent.run(ctx, wflw, events) + // Ensure we configure the current workflow and cancellation func before we launch the + // goroutine to avoid a race with CancelWorkflow. + agent.mtx.Lock() + defer agent.mtx.Unlock() + + ctx, cancel := context.WithCancel(ctx) + agent.executionContext = &executionContext{ + Workflow: wflw, + Cancel: cancel, + } + + go func() { + // Replenish the semaphore on exit so we can pick up another workflow. + defer func() { agent.sem <- struct{}{} }() + + if err := agent.run(ctx, wflw, events); err != nil { + // TODO(chrisdoherty4) An error indicates something catastrophic happened and we need + // to signal termination of the agent. + _ = err + } + + // Nilify the execution context after running so cancellation requests are ignored. + agent.mtx.Lock() + defer agent.mtx.Unlock() + agent.executionContext = nil + }() default: reject := event.WorkflowRejected{ ID: wflw.ID, Message: "workflow already in progress", } - events.RecordEvent(ctx, reject) - return nil + + // TODO(chrisdoherty) Change event recording to return an error because if we can't record + // events we need to exit the agent as somethings wrong. + events.RecordEvent(context.Background(), reject) + } +} + +func (agent *Agent) CancelWorkflow(workflowID string) { + agent.mtx.Lock() + defer agent.mtx.Unlock() + + if agent.executionContext == nil { + agent.Log.Info("No workflow running; ignoring cancellation request", "workflow_id", workflowID) + return } + + if agent.executionContext.Workflow.ID != workflowID { + agent.Log.Info( + "Incorrect workflow ID in cancellation request; ignoring cancellation request", + "workflow_id", workflowID, + "running_workflow_id", agent.executionContext.Workflow.ID, + ) + return + } + + agent.Log.Info("Cancelling workflow", "workflow_id", workflowID) + agent.executionContext.Cancel() +} + +type executionContext struct { + Workflow workflow.Workflow + Cancel context.CancelFunc } diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index 6c412b3d4..93ace0fdf 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -4,6 +4,7 @@ import ( "context" "strings" "testing" + "time" "github.com/go-logr/logr" "github.com/go-logr/zapr" @@ -59,6 +60,14 @@ func TestAgent_InvalidStart(t *testing.T) { Runtime: runtime.Noop(), }, }, + { + Name: "NoLogger", + Agent: agent.Agent{ + ID: "1234", + Transport: transport.Noop(), + Runtime: runtime.Noop(), + }, + }, } for _, tc := range cases { @@ -76,16 +85,13 @@ func TestAgent_InvalidStart(t *testing.T) { } } +// The goal of this test is to ensure the agent rejects concurrent workflows. func TestAgent_ConcurrentWorkflows(t *testing.T) { - // The goal of this test is to ensure the agent rejects concurrent workflows. We have to - // build a valid agent because it will also reject calls to HandleWorkflow without first - // starting the Agent. - logger := zapr.NewLogger(zap.Must(zap.NewDevelopment())) - recorder := event.NoopRecorder() + trnport := transport.Noop() - wrkflow := workflow.Workflow{ + wflw := workflow.Workflow{ ID: "1234", Actions: []workflow.Action{ { @@ -96,14 +102,8 @@ func TestAgent_ConcurrentWorkflows(t *testing.T) { }, } - trnport := agent.TransportMock{ - StartFunc: func(ctx context.Context, agentID string, handler workflow.Handler) error { - return handler.HandleWorkflow(ctx, wrkflow, recorder) - }, - } - + // Started is used to indicate the runtime has received the workflow. started := make(chan struct{}) - rntime := agent.ContainerRuntimeMock{ RunFunc: func(ctx context.Context, action workflow.Action) error { started <- struct{}{} @@ -119,26 +119,28 @@ func TestAgent_ConcurrentWorkflows(t *testing.T) { ID: "1234", } - // Build a cancellable context so we can tear the goroutine down. + // HandleWorkflow will reject us if we haven't started the agent first. + if err := agnt.Start(context.Background()); err != nil { + t.Fatal(err) + } - errs := make(chan error) - ctx, cancel := context.WithCancel(context.Background()) + // Build a cancellable context so we can tear everything down. The timeout is guesswork but + // this test shouldn't take long. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - go func() { errs <- agnt.Start(ctx) }() + // Handle the first workflow and wait for it to start. + agnt.HandleWorkflow(ctx, wflw, recorder) - // Await either an error or the mock runtime to tell us its stated. + // Wait for the container runtime to start. select { - case err := <-errs: - t.Fatalf("Unexpected error: %v", err) case <-started: + case <-ctx.Done(): + t.Fatal(ctx.Err()) } - // Attempt to fire off another workflow. - err := agnt.HandleWorkflow(context.Background(), wrkflow, recorder) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + // Attempt to fire off a second workflow. + agnt.HandleWorkflow(ctx, wflw, recorder) // Ensure the latest event recorded is a event.WorkflowRejected. calls := recorder.RecordEventCalls() @@ -153,7 +155,7 @@ func TestAgent_ConcurrentWorkflows(t *testing.T) { } expectEvent := event.WorkflowRejected{ - ID: wrkflow.ID, + ID: wflw.ID, Message: "workflow already in progress", } if !cmp.Equal(expectEvent, ev) { @@ -407,13 +409,7 @@ message`, for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { - recorder := event.NoopRecorder() - - trnport := agent.TransportMock{ - StartFunc: func(ctx context.Context, agentID string, handler workflow.Handler) error { - return handler.HandleWorkflow(ctx, tc.Workflow, recorder) - }, - } + trnport := transport.Noop() rntime := agent.ContainerRuntimeMock{ RunFunc: func(ctx context.Context, action workflow.Action) error { @@ -424,18 +420,43 @@ message`, }, } + // The event recorder is what tells us the workflow has finished executing so we use it + // to check for the last expected action. + lastEventReceived := make(chan struct{}) + recorder := event.RecorderMock{ + RecordEventFunc: func(contextMoqParam context.Context, event event.Event) { + if cmp.Equal(event, tc.Events[len(tc.Events)-1]) { + lastEventReceived <- struct{}{} + } + }, + } + + // Create and start the agent as start is a prereq to calling HandleWorkflow(). agnt := agent.Agent{ Log: logger, Transport: &trnport, Runtime: &rntime, ID: "1234", } - - err := agnt.Start(context.Background()) - if err != nil { + if err := agnt.Start(context.Background()); err != nil { t.Fatalf("Unexpected error: %v", err) } + // Configure a timeout of 5 seconds, this test shouldn't take long. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Handle the workflow + agnt.HandleWorkflow(ctx, tc.Workflow, &recorder) + + // Wait for the last expected event or timeout. + select { + case <-lastEventReceived: + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + + // Validate all events received are what we expected. calls := recorder.RecordEventCalls() if len(calls) != len(tc.Events) { t.Fatalf("Expected %v events; Received %v\n%+v", len(tc.Events), len(calls), calls) diff --git a/internal/agent/mocks.go b/internal/agent/mock.go similarity index 88% rename from internal/agent/mocks.go rename to internal/agent/mock.go index f7b8ddb6d..c7d50b390 100644 --- a/internal/agent/mocks.go +++ b/internal/agent/mock.go @@ -5,8 +5,10 @@ package agent import ( "context" - "github.com/tinkerbell/tink/internal/agent/workflow" "sync" + + "github.com/tinkerbell/tink/internal/agent/transport" + "github.com/tinkerbell/tink/internal/agent/workflow" ) // Ensure, that TransportMock does implement Transport. @@ -19,7 +21,7 @@ var _ Transport = &TransportMock{} // // // make and configure a mocked Transport // mockedTransport := &TransportMock{ -// StartFunc: func(contextMoqParam context.Context, agentID string, handler workflow.Handler) error { +// StartFunc: func(contextMoqParam context.Context, agentID string, workflowHandler transport.WorkflowHandler) error { // panic("mock out the Start method") // }, // } @@ -30,7 +32,7 @@ var _ Transport = &TransportMock{} // } type TransportMock struct { // StartFunc mocks the Start method. - StartFunc func(contextMoqParam context.Context, agentID string, handler workflow.Handler) error + StartFunc func(contextMoqParam context.Context, agentID string, workflowHandler transport.WorkflowHandler) error // calls tracks calls to the methods. calls struct { @@ -40,31 +42,31 @@ type TransportMock struct { ContextMoqParam context.Context // AgentID is the agentID argument value. AgentID string - // Handler is the handler argument value. - Handler workflow.Handler + // WorkflowHandler is the workflowHandler argument value. + WorkflowHandler transport.WorkflowHandler } } lockStart sync.RWMutex } // Start calls StartFunc. -func (mock *TransportMock) Start(contextMoqParam context.Context, agentID string, handler workflow.Handler) error { +func (mock *TransportMock) Start(contextMoqParam context.Context, agentID string, workflowHandler transport.WorkflowHandler) error { if mock.StartFunc == nil { panic("TransportMock.StartFunc: method is nil but Transport.Start was just called") } callInfo := struct { ContextMoqParam context.Context AgentID string - Handler workflow.Handler + WorkflowHandler transport.WorkflowHandler }{ ContextMoqParam: contextMoqParam, AgentID: agentID, - Handler: handler, + WorkflowHandler: workflowHandler, } mock.lockStart.Lock() mock.calls.Start = append(mock.calls.Start, callInfo) mock.lockStart.Unlock() - return mock.StartFunc(contextMoqParam, agentID, handler) + return mock.StartFunc(contextMoqParam, agentID, workflowHandler) } // StartCalls gets all the calls that were made to Start. @@ -74,12 +76,12 @@ func (mock *TransportMock) Start(contextMoqParam context.Context, agentID string func (mock *TransportMock) StartCalls() []struct { ContextMoqParam context.Context AgentID string - Handler workflow.Handler + WorkflowHandler transport.WorkflowHandler } { var calls []struct { ContextMoqParam context.Context AgentID string - Handler workflow.Handler + WorkflowHandler transport.WorkflowHandler } mock.lockStart.RLock() calls = mock.calls.Start diff --git a/internal/agent/transport.go b/internal/agent/transport.go index 6aaf616a2..bf204f367 100644 --- a/internal/agent/transport.go +++ b/internal/agent/transport.go @@ -3,7 +3,7 @@ package agent import ( "context" - "github.com/tinkerbell/tink/internal/agent/workflow" + "github.com/tinkerbell/tink/internal/agent/transport" ) // Transport is a transport mechanism for communicating workflows to the agent. @@ -11,5 +11,5 @@ type Transport interface { // Start is a blocking call that starts the transport and begins retrieving workflows for the // given agentID. The transport should pass workflows to the Handler. The transport // should block until its told to cancel via the context. - Start(_ context.Context, agentID string, _ workflow.Handler) error + Start(_ context.Context, agentID string, _ transport.WorkflowHandler) error } diff --git a/internal/agent/transport/fake.go b/internal/agent/transport/fake.go index 1379c3b21..d9a7830b1 100644 --- a/internal/agent/transport/fake.go +++ b/internal/agent/transport/fake.go @@ -19,12 +19,10 @@ type Fake struct { Workflows []workflow.Workflow } -func (f Fake) Start(ctx context.Context, _ string, runner workflow.Handler) error { +func (f Fake) Start(ctx context.Context, _ string, handler WorkflowHandler) error { f.Log.Info("Starting fake transport") for _, w := range f.Workflows { - if err := runner.HandleWorkflow(ctx, w, f); err != nil { - f.Log.Error(err, "Running workflow", "workflow", w) - } + handler.HandleWorkflow(ctx, w, f) } return nil } diff --git a/internal/agent/transport/grpc.go b/internal/agent/transport/grpc.go index c2bc8db49..0fd5653a8 100644 --- a/internal/agent/transport/grpc.go +++ b/internal/agent/transport/grpc.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "sync" "github.com/go-logr/logr" "github.com/tinkerbell/tink/internal/agent/event" @@ -36,13 +35,11 @@ func (g *GRPC) Start(ctx context.Context, agentID string, handler WorkflowHandle } log := g.log - var idx workflowIndex for { request, err := stream.Recv() switch { case errors.Is(err, io.EOF): - // TODO(chrisdoherty4) Think about cancelling return nil case err != nil: return err @@ -57,31 +54,10 @@ func (g *GRPC) Start(ctx context.Context, agentID string, handler WorkflowHandle continue } - wflw := toWorkflow(grpcWorkflow) - - // Start a new execution context so we can cancel it as needed. - ctx, err := idx.Insert(stream.Context(), wflw.ID) - if err != nil { - // Handle already excuting workflow. Perhaps this needs to be an agent concern - // so that multiple transports benefit from the same handling. Or, given its - // already running, perhaps we just log we were asked to run the same workflow - // twice. - _ = err - } - - go func(ctx context.Context, wflw workflow.Workflow) { - if err := handler.HandleWorkflow(ctx, wflw, g); err != nil { - log.Info("Failed to handle workflow", "error", err) - } - - // Stop the execution context so we're no longer tracking the workflow. - idx.Cancel(wflw.ID) - }(ctx, wflw) + handler.HandleWorkflow(ctx, toWorkflow(grpcWorkflow), g) case *workflowproto.StreamWorkflowsResponse_StopWorkflow_: - req := request.GetStopWorkflow() - // TODO: Validate workflow ID - idx.Cancel(req.WorkflowId) + handler.CancelWorkflow(request.GetStopWorkflow().WorkflowId) } } } @@ -186,41 +162,3 @@ func toGRPC(e event.Event) (*workflowproto.Event, error) { Event: e, }) } - -type workflowIndex struct { - cancellers map[string]context.CancelFunc - mtx sync.Mutex -} - -func (c *workflowIndex) Insert(ctx context.Context, id string) (context.Context, error) { - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.cancellers == nil { - c.cancellers = map[string]context.CancelFunc{} - } - - if _, ok := c.cancellers[id]; ok { - return nil, fmt.Errorf("workflow is already tracked (%v)", id) - } - - // Create a new cancellation function and add it to the c - ctx, cancel := context.WithCancel(ctx) - c.cancellers[id] = cancel - return ctx, nil -} - -func (c *workflowIndex) Cancel(id string) { - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.cancellers == nil { - return - } - - if cancel, ok := c.cancellers[id]; ok { - cancel() - } - - delete(c.cancellers, id) -} diff --git a/internal/agent/transport/grpc_test.go b/internal/agent/transport/grpc_test.go new file mode 100644 index 000000000..c24215481 --- /dev/null +++ b/internal/agent/transport/grpc_test.go @@ -0,0 +1,76 @@ +package transport_test + +import ( + "context" + "fmt" + "io" + "sync" + "testing" + + "github.com/go-logr/zerologr" + "github.com/kr/pretty" + "github.com/rs/zerolog" + "github.com/tinkerbell/tink/internal/agent/event" + "github.com/tinkerbell/tink/internal/agent/transport" + "github.com/tinkerbell/tink/internal/agent/workflow" + workflowproto "github.com/tinkerbell/tink/internal/proto/workflow/v2" + "google.golang.org/grpc" +) + +func TestGRPC(t *testing.T) { + logger := zerolog.New(zerolog.NewConsoleWriter()) + type streamResponse struct { + Workflow *workflowproto.StreamWorkflowsResponse + Error error + } + responses := make(chan streamResponse, 2) + responses <- streamResponse{ + Workflow: &workflowproto.StreamWorkflowsResponse{ + Cmd: &workflowproto.StreamWorkflowsResponse_StartWorkflow_{ + StartWorkflow: &workflowproto.StreamWorkflowsResponse_StartWorkflow{ + Workflow: &workflowproto.Workflow{}, + }, + }, + }, + } + responses <- streamResponse{ + Error: io.EOF, + } + + stream := &workflowproto.WorkflowService_StreamWorkflowsClientMock{ + RecvFunc: func() (*workflowproto.StreamWorkflowsResponse, error) { + r, ok := <-responses + if !ok { + return nil, io.EOF + } + return r.Workflow, r.Error + }, + ContextFunc: context.Background, + } + client := &workflowproto.WorkflowServiceClientMock{ + StreamWorkflowsFunc: func(ctx context.Context, in *workflowproto.StreamWorkflowsRequest, opts ...grpc.CallOption) (workflowproto.WorkflowService_StreamWorkflowsClient, error) { + return stream, nil + }, + } + + var wg sync.WaitGroup + wg.Add(1) + handler := &transport.WorkflowHandlerMock{ + HandleWorkflowFunc: func(contextMoqParam context.Context, workflow workflow.Workflow, recorder event.Recorder) { + defer wg.Done() + fmt.Println("hanlding") + close(responses) + }, + } + + g := transport.NewGRPC(zerologr.New(&logger), client) + + err := g.Start(context.Background(), "id", handler) + if err != nil { + t.Fatal(err) + } + + wg.Wait() + + pretty.Println(handler.HandleWorkflowCalls()) +} diff --git a/internal/agent/transport/handler.go b/internal/agent/transport/handler.go index c8da2a9e3..f483c3493 100644 --- a/internal/agent/transport/handler.go +++ b/internal/agent/transport/handler.go @@ -7,10 +7,18 @@ import ( "github.com/tinkerbell/tink/internal/agent/workflow" ) +// Change everything so we can launch and cancel workflows with APIs instead of contexts. +// If we find there's a catastrophic error in the agent, just have the agent tear down the transport +// rather than trying to signal the transport to stop. + // WorkflowHandler is responsible for handling workflow execution. type WorkflowHandler interface { - // HandleWorkflow begins executing the given workflow. The event recorder can be used to - // indicate the progress of a workflow. If the given context becomes cancelled, the workflow - // handler should stop workflow execution. - HandleWorkflow(context.Context, workflow.Workflow, event.Recorder) error + // HandleWorkflow executes the given workflow. The event.Recorder can be used to publish events + // as the workflow transits its lifecycle. HandleWorkflow should not block and should be efficient + // in handing off workflow processing. + HandleWorkflow(context.Context, workflow.Workflow, event.Recorder) + + // CancelWorkflow cancels a workflow identified by workflowID. It should not block and should + // be efficient in handing off the cancellation request. + CancelWorkflow(workflowID string) } diff --git a/internal/agent/transport/mock.go b/internal/agent/transport/mock.go new file mode 100644 index 000000000..e73ba35f4 --- /dev/null +++ b/internal/agent/transport/mock.go @@ -0,0 +1,134 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package transport + +import ( + "context" + "sync" + + "github.com/tinkerbell/tink/internal/agent/event" + "github.com/tinkerbell/tink/internal/agent/workflow" +) + +// Ensure, that WorkflowHandlerMock does implement WorkflowHandler. +// If this is not the case, regenerate this file with moq. +var _ WorkflowHandler = &WorkflowHandlerMock{} + +// WorkflowHandlerMock is a mock implementation of WorkflowHandler. +// +// func TestSomethingThatUsesWorkflowHandler(t *testing.T) { +// +// // make and configure a mocked WorkflowHandler +// mockedWorkflowHandler := &WorkflowHandlerMock{ +// CancelWorkflowFunc: func(workflowID string) { +// panic("mock out the CancelWorkflow method") +// }, +// HandleWorkflowFunc: func(contextMoqParam context.Context, workflowMoqParam workflow.Workflow, recorder event.Recorder) { +// panic("mock out the HandleWorkflow method") +// }, +// } +// +// // use mockedWorkflowHandler in code that requires WorkflowHandler +// // and then make assertions. +// +// } +type WorkflowHandlerMock struct { + // CancelWorkflowFunc mocks the CancelWorkflow method. + CancelWorkflowFunc func(workflowID string) + + // HandleWorkflowFunc mocks the HandleWorkflow method. + HandleWorkflowFunc func(contextMoqParam context.Context, workflowMoqParam workflow.Workflow, recorder event.Recorder) + + // calls tracks calls to the methods. + calls struct { + // CancelWorkflow holds details about calls to the CancelWorkflow method. + CancelWorkflow []struct { + // WorkflowID is the workflowID argument value. + WorkflowID string + } + // HandleWorkflow holds details about calls to the HandleWorkflow method. + HandleWorkflow []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // WorkflowMoqParam is the workflowMoqParam argument value. + WorkflowMoqParam workflow.Workflow + // Recorder is the recorder argument value. + Recorder event.Recorder + } + } + lockCancelWorkflow sync.RWMutex + lockHandleWorkflow sync.RWMutex +} + +// CancelWorkflow calls CancelWorkflowFunc. +func (mock *WorkflowHandlerMock) CancelWorkflow(workflowID string) { + if mock.CancelWorkflowFunc == nil { + panic("WorkflowHandlerMock.CancelWorkflowFunc: method is nil but WorkflowHandler.CancelWorkflow was just called") + } + callInfo := struct { + WorkflowID string + }{ + WorkflowID: workflowID, + } + mock.lockCancelWorkflow.Lock() + mock.calls.CancelWorkflow = append(mock.calls.CancelWorkflow, callInfo) + mock.lockCancelWorkflow.Unlock() + mock.CancelWorkflowFunc(workflowID) +} + +// CancelWorkflowCalls gets all the calls that were made to CancelWorkflow. +// Check the length with: +// +// len(mockedWorkflowHandler.CancelWorkflowCalls()) +func (mock *WorkflowHandlerMock) CancelWorkflowCalls() []struct { + WorkflowID string +} { + var calls []struct { + WorkflowID string + } + mock.lockCancelWorkflow.RLock() + calls = mock.calls.CancelWorkflow + mock.lockCancelWorkflow.RUnlock() + return calls +} + +// HandleWorkflow calls HandleWorkflowFunc. +func (mock *WorkflowHandlerMock) HandleWorkflow(contextMoqParam context.Context, workflowMoqParam workflow.Workflow, recorder event.Recorder) { + if mock.HandleWorkflowFunc == nil { + panic("WorkflowHandlerMock.HandleWorkflowFunc: method is nil but WorkflowHandler.HandleWorkflow was just called") + } + callInfo := struct { + ContextMoqParam context.Context + WorkflowMoqParam workflow.Workflow + Recorder event.Recorder + }{ + ContextMoqParam: contextMoqParam, + WorkflowMoqParam: workflowMoqParam, + Recorder: recorder, + } + mock.lockHandleWorkflow.Lock() + mock.calls.HandleWorkflow = append(mock.calls.HandleWorkflow, callInfo) + mock.lockHandleWorkflow.Unlock() + mock.HandleWorkflowFunc(contextMoqParam, workflowMoqParam, recorder) +} + +// HandleWorkflowCalls gets all the calls that were made to HandleWorkflow. +// Check the length with: +// +// len(mockedWorkflowHandler.HandleWorkflowCalls()) +func (mock *WorkflowHandlerMock) HandleWorkflowCalls() []struct { + ContextMoqParam context.Context + WorkflowMoqParam workflow.Workflow + Recorder event.Recorder +} { + var calls []struct { + ContextMoqParam context.Context + WorkflowMoqParam workflow.Workflow + Recorder event.Recorder + } + mock.lockHandleWorkflow.RLock() + calls = mock.calls.HandleWorkflow + mock.lockHandleWorkflow.RUnlock() + return calls +} diff --git a/internal/agent/workflow/handler.go b/internal/agent/workflow/handler.go deleted file mode 100644 index e63f7c8e6..000000000 --- a/internal/agent/workflow/handler.go +++ /dev/null @@ -1,15 +0,0 @@ -package workflow - -import ( - "context" - - "github.com/tinkerbell/tink/internal/agent/event" -) - -// Handler is responsible for handling workflow execution. -type Handler interface { - // HandleWorkflow begins executing the given workflow. The event recorder can be used to - // indicate the progress of a workflow. If the given context becomes cancelled, the workflow - // handler should stop workflow execution. - HandleWorkflow(context.Context, Workflow, event.Recorder) error -} diff --git a/internal/proto/workflow/v2/mock.go b/internal/proto/workflow/v2/mock.go new file mode 100644 index 000000000..7641ec44d --- /dev/null +++ b/internal/proto/workflow/v2/mock.go @@ -0,0 +1,440 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package workflow + +import ( + context "context" + grpc "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + sync "sync" +) + +// Ensure, that WorkflowServiceClientMock does implement WorkflowServiceClient. +// If this is not the case, regenerate this file with moq. +var _ WorkflowServiceClient = &WorkflowServiceClientMock{} + +// WorkflowServiceClientMock is a mock implementation of WorkflowServiceClient. +// +// func TestSomethingThatUsesWorkflowServiceClient(t *testing.T) { +// +// // make and configure a mocked WorkflowServiceClient +// mockedWorkflowServiceClient := &WorkflowServiceClientMock{ +// PublishEventFunc: func(ctx context.Context, in *PublishEventRequest, opts ...grpc.CallOption) (*PublishEventResponse, error) { +// panic("mock out the PublishEvent method") +// }, +// StreamWorkflowsFunc: func(ctx context.Context, in *StreamWorkflowsRequest, opts ...grpc.CallOption) (WorkflowService_StreamWorkflowsClient, error) { +// panic("mock out the StreamWorkflows method") +// }, +// } +// +// // use mockedWorkflowServiceClient in code that requires WorkflowServiceClient +// // and then make assertions. +// +// } +type WorkflowServiceClientMock struct { + // PublishEventFunc mocks the PublishEvent method. + PublishEventFunc func(ctx context.Context, in *PublishEventRequest, opts ...grpc.CallOption) (*PublishEventResponse, error) + + // StreamWorkflowsFunc mocks the StreamWorkflows method. + StreamWorkflowsFunc func(ctx context.Context, in *StreamWorkflowsRequest, opts ...grpc.CallOption) (WorkflowService_StreamWorkflowsClient, error) + + // calls tracks calls to the methods. + calls struct { + // PublishEvent holds details about calls to the PublishEvent method. + PublishEvent []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // In is the in argument value. + In *PublishEventRequest + // Opts is the opts argument value. + Opts []grpc.CallOption + } + // StreamWorkflows holds details about calls to the StreamWorkflows method. + StreamWorkflows []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // In is the in argument value. + In *StreamWorkflowsRequest + // Opts is the opts argument value. + Opts []grpc.CallOption + } + } + lockPublishEvent sync.RWMutex + lockStreamWorkflows sync.RWMutex +} + +// PublishEvent calls PublishEventFunc. +func (mock *WorkflowServiceClientMock) PublishEvent(ctx context.Context, in *PublishEventRequest, opts ...grpc.CallOption) (*PublishEventResponse, error) { + if mock.PublishEventFunc == nil { + panic("WorkflowServiceClientMock.PublishEventFunc: method is nil but WorkflowServiceClient.PublishEvent was just called") + } + callInfo := struct { + Ctx context.Context + In *PublishEventRequest + Opts []grpc.CallOption + }{ + Ctx: ctx, + In: in, + Opts: opts, + } + mock.lockPublishEvent.Lock() + mock.calls.PublishEvent = append(mock.calls.PublishEvent, callInfo) + mock.lockPublishEvent.Unlock() + return mock.PublishEventFunc(ctx, in, opts...) +} + +// PublishEventCalls gets all the calls that were made to PublishEvent. +// Check the length with: +// +// len(mockedWorkflowServiceClient.PublishEventCalls()) +func (mock *WorkflowServiceClientMock) PublishEventCalls() []struct { + Ctx context.Context + In *PublishEventRequest + Opts []grpc.CallOption +} { + var calls []struct { + Ctx context.Context + In *PublishEventRequest + Opts []grpc.CallOption + } + mock.lockPublishEvent.RLock() + calls = mock.calls.PublishEvent + mock.lockPublishEvent.RUnlock() + return calls +} + +// StreamWorkflows calls StreamWorkflowsFunc. +func (mock *WorkflowServiceClientMock) StreamWorkflows(ctx context.Context, in *StreamWorkflowsRequest, opts ...grpc.CallOption) (WorkflowService_StreamWorkflowsClient, error) { + if mock.StreamWorkflowsFunc == nil { + panic("WorkflowServiceClientMock.StreamWorkflowsFunc: method is nil but WorkflowServiceClient.StreamWorkflows was just called") + } + callInfo := struct { + Ctx context.Context + In *StreamWorkflowsRequest + Opts []grpc.CallOption + }{ + Ctx: ctx, + In: in, + Opts: opts, + } + mock.lockStreamWorkflows.Lock() + mock.calls.StreamWorkflows = append(mock.calls.StreamWorkflows, callInfo) + mock.lockStreamWorkflows.Unlock() + return mock.StreamWorkflowsFunc(ctx, in, opts...) +} + +// StreamWorkflowsCalls gets all the calls that were made to StreamWorkflows. +// Check the length with: +// +// len(mockedWorkflowServiceClient.StreamWorkflowsCalls()) +func (mock *WorkflowServiceClientMock) StreamWorkflowsCalls() []struct { + Ctx context.Context + In *StreamWorkflowsRequest + Opts []grpc.CallOption +} { + var calls []struct { + Ctx context.Context + In *StreamWorkflowsRequest + Opts []grpc.CallOption + } + mock.lockStreamWorkflows.RLock() + calls = mock.calls.StreamWorkflows + mock.lockStreamWorkflows.RUnlock() + return calls +} + +// Ensure, that WorkflowService_StreamWorkflowsClientMock does implement WorkflowService_StreamWorkflowsClient. +// If this is not the case, regenerate this file with moq. +var _ WorkflowService_StreamWorkflowsClient = &WorkflowService_StreamWorkflowsClientMock{} + +// WorkflowService_StreamWorkflowsClientMock is a mock implementation of WorkflowService_StreamWorkflowsClient. +// +// func TestSomethingThatUsesWorkflowService_StreamWorkflowsClient(t *testing.T) { +// +// // make and configure a mocked WorkflowService_StreamWorkflowsClient +// mockedWorkflowService_StreamWorkflowsClient := &WorkflowService_StreamWorkflowsClientMock{ +// CloseSendFunc: func() error { +// panic("mock out the CloseSend method") +// }, +// ContextFunc: func() context.Context { +// panic("mock out the Context method") +// }, +// HeaderFunc: func() (metadata.MD, error) { +// panic("mock out the Header method") +// }, +// RecvFunc: func() (*StreamWorkflowsResponse, error) { +// panic("mock out the Recv method") +// }, +// RecvMsgFunc: func(m interface{}) error { +// panic("mock out the RecvMsg method") +// }, +// SendMsgFunc: func(m interface{}) error { +// panic("mock out the SendMsg method") +// }, +// TrailerFunc: func() metadata.MD { +// panic("mock out the Trailer method") +// }, +// } +// +// // use mockedWorkflowService_StreamWorkflowsClient in code that requires WorkflowService_StreamWorkflowsClient +// // and then make assertions. +// +// } +type WorkflowService_StreamWorkflowsClientMock struct { + // CloseSendFunc mocks the CloseSend method. + CloseSendFunc func() error + + // ContextFunc mocks the Context method. + ContextFunc func() context.Context + + // HeaderFunc mocks the Header method. + HeaderFunc func() (metadata.MD, error) + + // RecvFunc mocks the Recv method. + RecvFunc func() (*StreamWorkflowsResponse, error) + + // RecvMsgFunc mocks the RecvMsg method. + RecvMsgFunc func(m interface{}) error + + // SendMsgFunc mocks the SendMsg method. + SendMsgFunc func(m interface{}) error + + // TrailerFunc mocks the Trailer method. + TrailerFunc func() metadata.MD + + // calls tracks calls to the methods. + calls struct { + // CloseSend holds details about calls to the CloseSend method. + CloseSend []struct { + } + // Context holds details about calls to the Context method. + Context []struct { + } + // Header holds details about calls to the Header method. + Header []struct { + } + // Recv holds details about calls to the Recv method. + Recv []struct { + } + // RecvMsg holds details about calls to the RecvMsg method. + RecvMsg []struct { + // M is the m argument value. + M interface{} + } + // SendMsg holds details about calls to the SendMsg method. + SendMsg []struct { + // M is the m argument value. + M interface{} + } + // Trailer holds details about calls to the Trailer method. + Trailer []struct { + } + } + lockCloseSend sync.RWMutex + lockContext sync.RWMutex + lockHeader sync.RWMutex + lockRecv sync.RWMutex + lockRecvMsg sync.RWMutex + lockSendMsg sync.RWMutex + lockTrailer sync.RWMutex +} + +// CloseSend calls CloseSendFunc. +func (mock *WorkflowService_StreamWorkflowsClientMock) CloseSend() error { + if mock.CloseSendFunc == nil { + panic("WorkflowService_StreamWorkflowsClientMock.CloseSendFunc: method is nil but WorkflowService_StreamWorkflowsClient.CloseSend was just called") + } + callInfo := struct { + }{} + mock.lockCloseSend.Lock() + mock.calls.CloseSend = append(mock.calls.CloseSend, callInfo) + mock.lockCloseSend.Unlock() + return mock.CloseSendFunc() +} + +// CloseSendCalls gets all the calls that were made to CloseSend. +// Check the length with: +// +// len(mockedWorkflowService_StreamWorkflowsClient.CloseSendCalls()) +func (mock *WorkflowService_StreamWorkflowsClientMock) CloseSendCalls() []struct { +} { + var calls []struct { + } + mock.lockCloseSend.RLock() + calls = mock.calls.CloseSend + mock.lockCloseSend.RUnlock() + return calls +} + +// Context calls ContextFunc. +func (mock *WorkflowService_StreamWorkflowsClientMock) Context() context.Context { + if mock.ContextFunc == nil { + panic("WorkflowService_StreamWorkflowsClientMock.ContextFunc: method is nil but WorkflowService_StreamWorkflowsClient.Context was just called") + } + callInfo := struct { + }{} + mock.lockContext.Lock() + mock.calls.Context = append(mock.calls.Context, callInfo) + mock.lockContext.Unlock() + return mock.ContextFunc() +} + +// ContextCalls gets all the calls that were made to Context. +// Check the length with: +// +// len(mockedWorkflowService_StreamWorkflowsClient.ContextCalls()) +func (mock *WorkflowService_StreamWorkflowsClientMock) ContextCalls() []struct { +} { + var calls []struct { + } + mock.lockContext.RLock() + calls = mock.calls.Context + mock.lockContext.RUnlock() + return calls +} + +// Header calls HeaderFunc. +func (mock *WorkflowService_StreamWorkflowsClientMock) Header() (metadata.MD, error) { + if mock.HeaderFunc == nil { + panic("WorkflowService_StreamWorkflowsClientMock.HeaderFunc: method is nil but WorkflowService_StreamWorkflowsClient.Header was just called") + } + callInfo := struct { + }{} + mock.lockHeader.Lock() + mock.calls.Header = append(mock.calls.Header, callInfo) + mock.lockHeader.Unlock() + return mock.HeaderFunc() +} + +// HeaderCalls gets all the calls that were made to Header. +// Check the length with: +// +// len(mockedWorkflowService_StreamWorkflowsClient.HeaderCalls()) +func (mock *WorkflowService_StreamWorkflowsClientMock) HeaderCalls() []struct { +} { + var calls []struct { + } + mock.lockHeader.RLock() + calls = mock.calls.Header + mock.lockHeader.RUnlock() + return calls +} + +// Recv calls RecvFunc. +func (mock *WorkflowService_StreamWorkflowsClientMock) Recv() (*StreamWorkflowsResponse, error) { + if mock.RecvFunc == nil { + panic("WorkflowService_StreamWorkflowsClientMock.RecvFunc: method is nil but WorkflowService_StreamWorkflowsClient.Recv was just called") + } + callInfo := struct { + }{} + mock.lockRecv.Lock() + mock.calls.Recv = append(mock.calls.Recv, callInfo) + mock.lockRecv.Unlock() + return mock.RecvFunc() +} + +// RecvCalls gets all the calls that were made to Recv. +// Check the length with: +// +// len(mockedWorkflowService_StreamWorkflowsClient.RecvCalls()) +func (mock *WorkflowService_StreamWorkflowsClientMock) RecvCalls() []struct { +} { + var calls []struct { + } + mock.lockRecv.RLock() + calls = mock.calls.Recv + mock.lockRecv.RUnlock() + return calls +} + +// RecvMsg calls RecvMsgFunc. +func (mock *WorkflowService_StreamWorkflowsClientMock) RecvMsg(m interface{}) error { + if mock.RecvMsgFunc == nil { + panic("WorkflowService_StreamWorkflowsClientMock.RecvMsgFunc: method is nil but WorkflowService_StreamWorkflowsClient.RecvMsg was just called") + } + callInfo := struct { + M interface{} + }{ + M: m, + } + mock.lockRecvMsg.Lock() + mock.calls.RecvMsg = append(mock.calls.RecvMsg, callInfo) + mock.lockRecvMsg.Unlock() + return mock.RecvMsgFunc(m) +} + +// RecvMsgCalls gets all the calls that were made to RecvMsg. +// Check the length with: +// +// len(mockedWorkflowService_StreamWorkflowsClient.RecvMsgCalls()) +func (mock *WorkflowService_StreamWorkflowsClientMock) RecvMsgCalls() []struct { + M interface{} +} { + var calls []struct { + M interface{} + } + mock.lockRecvMsg.RLock() + calls = mock.calls.RecvMsg + mock.lockRecvMsg.RUnlock() + return calls +} + +// SendMsg calls SendMsgFunc. +func (mock *WorkflowService_StreamWorkflowsClientMock) SendMsg(m interface{}) error { + if mock.SendMsgFunc == nil { + panic("WorkflowService_StreamWorkflowsClientMock.SendMsgFunc: method is nil but WorkflowService_StreamWorkflowsClient.SendMsg was just called") + } + callInfo := struct { + M interface{} + }{ + M: m, + } + mock.lockSendMsg.Lock() + mock.calls.SendMsg = append(mock.calls.SendMsg, callInfo) + mock.lockSendMsg.Unlock() + return mock.SendMsgFunc(m) +} + +// SendMsgCalls gets all the calls that were made to SendMsg. +// Check the length with: +// +// len(mockedWorkflowService_StreamWorkflowsClient.SendMsgCalls()) +func (mock *WorkflowService_StreamWorkflowsClientMock) SendMsgCalls() []struct { + M interface{} +} { + var calls []struct { + M interface{} + } + mock.lockSendMsg.RLock() + calls = mock.calls.SendMsg + mock.lockSendMsg.RUnlock() + return calls +} + +// Trailer calls TrailerFunc. +func (mock *WorkflowService_StreamWorkflowsClientMock) Trailer() metadata.MD { + if mock.TrailerFunc == nil { + panic("WorkflowService_StreamWorkflowsClientMock.TrailerFunc: method is nil but WorkflowService_StreamWorkflowsClient.Trailer was just called") + } + callInfo := struct { + }{} + mock.lockTrailer.Lock() + mock.calls.Trailer = append(mock.calls.Trailer, callInfo) + mock.lockTrailer.Unlock() + return mock.TrailerFunc() +} + +// TrailerCalls gets all the calls that were made to Trailer. +// Check the length with: +// +// len(mockedWorkflowService_StreamWorkflowsClient.TrailerCalls()) +func (mock *WorkflowService_StreamWorkflowsClientMock) TrailerCalls() []struct { +} { + var calls []struct { + } + mock.lockTrailer.RLock() + calls = mock.calls.Trailer + mock.lockTrailer.RUnlock() + return calls +}