Skip to content

Commit

Permalink
u
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisdoherty4 committed May 17, 2023
1 parent ac043bd commit 1e575b5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 52 deletions.
22 changes: 6 additions & 16 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ func (agent *Agent) Start(ctx context.Context) error {
}

if agent.Runtime == nil {
return errors.New("agent.Runtime must be set before calling Start()")
//nolint:stylecheck // Runtime is a field of agent.
return errors.New("Runtime field must be set before calling Start()")
}

if agent.Log.GetSink() == nil {
agent.Log = logr.Discard()
}

agent.Log = agent.Log.WithValues("agent_id", agent.ID)
Expand All @@ -59,21 +64,6 @@ func (agent *Agent) Start(ctx context.Context) error {
agent.sem <- struct{}{}

return agent.Transport.Start(ctx, agent.ID, agent)

// // Launch the transport ensuring we can recover any errors.
// transportErr := make(chan error, 1)
// go func() {
// agent.Log.Info("Starting agent")
// transportErr <- agent.Transport.Start(ctx, agent.ID, agent)
// }()

// select {
// case err := <-transportErr:
// return fmt.Errorf("transport: %w", err)
// case <-ctx.Done():
// // Cancel any running workflows.
// return ctx.Err()
// }
}

// HandleWorkflow satisfies transport.
Expand Down
91 changes: 55 additions & 36 deletions internal/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"strings"
"testing"
"time"

"github.com/go-logr/logr"
"github.com/go-logr/zapr"
Expand Down Expand Up @@ -59,6 +60,14 @@ func TestAgent_InvalidStart(t *testing.T) {
Runtime: runtime.Noop(),
},
},
{
Name: "NoLogger",
Agent: agent.Agent{
ID: "1234",
Transport: transport.Noop(),
Runtime: runtime.Noop(),
},
},
}

for _, tc := range cases {
Expand All @@ -76,14 +85,11 @@ func TestAgent_InvalidStart(t *testing.T) {
}
}

// The goal of this test is to ensure the agent rejects concurrent workflows.
func TestAgent_ConcurrentWorkflows(t *testing.T) {
// The goal of this test is to ensure the agent rejects concurrent workflows. We have to
// build a valid agent because it will also reject calls to HandleWorkflow without first
// starting the Agent.

logger := zapr.NewLogger(zap.Must(zap.NewDevelopment()))

recorder := event.NoopRecorder()
trnport := transport.Noop()

wflw := workflow.Workflow{
ID: "1234",
Expand All @@ -96,15 +102,8 @@ func TestAgent_ConcurrentWorkflows(t *testing.T) {
},
}

trnport := agent.TransportMock{
StartFunc: func(ctx context.Context, agentID string, handler transport.WorkflowHandler) error {
handler.HandleWorkflow(ctx, wflw, recorder)
return nil
},
}

// Started is used to indicate the runtime has received the workflow.
started := make(chan struct{})

rntime := agent.ContainerRuntimeMock{
RunFunc: func(ctx context.Context, action workflow.Action) error {
started <- struct{}{}
Expand All @@ -120,26 +119,28 @@ func TestAgent_ConcurrentWorkflows(t *testing.T) {
ID: "1234",
}

// Build a cancellable context so we can tear the goroutine down.
// HandleWorkflow will reject us if we haven't started the agent first.
if err := agnt.Start(context.Background()); err != nil {
t.Fatal(err)
}

errs := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
// Build a cancellable context so we can tear everything down. The timeout is guesswork but
// this test shouldn't take long.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

go func() { errs <- agnt.Start(ctx) }()
// Handle the first workflow and wait for it to start.
agnt.HandleWorkflow(ctx, wflw, recorder)

// Await either an error or the mock runtime to tell us its stated.
// Wait for the container runtime to start.
select {
case err := <-errs:
t.Fatalf("Unexpected error: %v", err)
case <-started:
case <-ctx.Done():
t.Fatal(ctx.Err())
}

// Attempt to fire off another workflow.
// Attempt to fire off a second workflow.
agnt.HandleWorkflow(ctx, wflw, recorder)
// if err != nil {
// t.Fatalf("Unexpected error: %v", err)
// }

// Ensure the latest event recorded is a event.WorkflowRejected.
calls := recorder.RecordEventCalls()
Expand Down Expand Up @@ -387,7 +388,7 @@ func TestAgent_HandlingWorkflows(t *testing.T) {
Errors: map[string]ReasonAndMessage{
"1": {
Reason: "TestReason",
Message: `invalid
Message: `invalid
message`,
},
},
Expand All @@ -408,14 +409,7 @@ message`,

for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
recorder := event.NoopRecorder()

trnport := agent.TransportMock{
StartFunc: func(ctx context.Context, agentID string, handler transport.WorkflowHandler) error {
handler.HandleWorkflow(ctx, tc.Workflow, recorder)
return nil
},
}
trnport := transport.Noop()

rntime := agent.ContainerRuntimeMock{
RunFunc: func(ctx context.Context, action workflow.Action) error {
Expand All @@ -426,18 +420,43 @@ message`,
},
}

// The event recorder is what tells us the workflow has finished executing so we use it
// to check for the last expected action.
lastEventReceived := make(chan struct{})
recorder := event.RecorderMock{
RecordEventFunc: func(contextMoqParam context.Context, event event.Event) {
if cmp.Equal(event, tc.Events[len(tc.Events)-1]) {
lastEventReceived <- struct{}{}
}
},
}

// Create and start the agent as start is a prereq to calling HandleWorkflow().
agnt := agent.Agent{
Log: logger,
Transport: &trnport,
Runtime: &rntime,
ID: "1234",
}

err := agnt.Start(context.Background())
if err != nil {
if err := agnt.Start(context.Background()); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Configure a timeout of 5 seconds, this test shouldn't take long.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Handle the workflow
agnt.HandleWorkflow(ctx, tc.Workflow, &recorder)

// Wait for the last expected event or timeout.
select {
case <-lastEventReceived:
case <-ctx.Done():
t.Fatal(ctx.Err())
}

// Validate all events received are what we expected.
calls := recorder.RecordEventCalls()
if len(calls) != len(tc.Events) {
t.Fatalf("Expected %v events; Received %v\n%+v", len(tc.Events), len(calls), calls)
Expand Down

0 comments on commit 1e575b5

Please sign in to comment.