diff --git a/agent/consul/server.go b/agent/consul/server.go index 52236e8b5d004..9a09be4172139 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -844,6 +844,8 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom func (s *Server) registerResources() { catalog.RegisterTypes(s.typeRegistry) + catalog.RegisterControllers(s.controllerManager) + mesh.RegisterTypes(s.typeRegistry) reaper.RegisterControllers(s.controllerManager) diff --git a/internal/catalog/exports.go b/internal/catalog/exports.go index 4f3ddb6a9d8d7..39d6c44e66596 100644 --- a/internal/catalog/exports.go +++ b/internal/catalog/exports.go @@ -4,7 +4,9 @@ package catalog import ( + "github.com/hashicorp/consul/internal/catalog/internal/controllers" "github.com/hashicorp/consul/internal/catalog/internal/types" + "github.com/hashicorp/consul/internal/controller" "github.com/hashicorp/consul/internal/resource" ) @@ -43,3 +45,9 @@ var ( func RegisterTypes(r resource.Registry) { types.Register(r) } + +// RegisterControllers registers controllers for the catalog types with +// the given controller Manager. +func RegisterControllers(mgr *controller.Manager) { + controllers.Register(mgr) +} diff --git a/internal/catalog/internal/controllers/nodehealth/controller.go b/internal/catalog/internal/controllers/nodehealth/controller.go new file mode 100644 index 0000000000000..a439d57756d82 --- /dev/null +++ b/internal/catalog/internal/controllers/nodehealth/controller.go @@ -0,0 +1,105 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package nodehealth + +import ( + "context" + "fmt" + + "github.com/hashicorp/consul/internal/catalog/internal/types" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +func NodeHealthController() controller.Controller { + return controller.ForType(types.NodeType). + WithWatch(types.HealthStatusType, controller.MapOwnerFiltered(types.NodeType)). + WithReconciler(&nodeHealthReconciler{}) +} + +type nodeHealthReconciler struct{} + +func (r *nodeHealthReconciler) Reconcile(ctx context.Context, rt controller.Runtime, req controller.Request) error { + // read the workload + rsp, err := rt.Client.Read(ctx, &pbresource.ReadRequest{Id: req.ID}) + switch { + case status.Code(err) == codes.NotFound: + return nil + case err != nil: + return err + } + + res := rsp.Resource + + health, err := getNodeHealth(ctx, rt, req.ID) + if err != nil { + return err + } + + message := NodeHealthyMessage + statusState := pbresource.Condition_STATE_TRUE + if health != pbcatalog.Health_HEALTH_PASSING { + statusState = pbresource.Condition_STATE_FALSE + message = NodeUnhealthyMessage + } + + newStatus := &pbresource.Status{ + ObservedGeneration: res.Generation, + Conditions: []*pbresource.Condition{ + { + Type: StatusConditionHealthy, + State: statusState, + Reason: health.String(), + Message: message, + }, + }, + } + + if resource.EqualStatus(res.Status[StatusKey], newStatus, false) { + return nil + } + + _, err = rt.Client.WriteStatus(ctx, &pbresource.WriteStatusRequest{ + Id: res.Id, + Key: StatusKey, + Status: newStatus, + }) + + return err +} + +func getNodeHealth(ctx context.Context, rt controller.Runtime, nodeRef *pbresource.ID) (pbcatalog.Health, error) { + rsp, err := rt.Client.ListByOwner(ctx, &pbresource.ListByOwnerRequest{ + Owner: nodeRef, + }) + + if err != nil { + return pbcatalog.Health_HEALTH_CRITICAL, err + } + + health := pbcatalog.Health_HEALTH_PASSING + + for _, res := range rsp.Resources { + if proto.Equal(res.Id.Type, types.HealthStatusType) { + var hs pbcatalog.HealthStatus + if err := res.Data.UnmarshalTo(&hs); err != nil { + // This should be impossible as the resource service + type validations the + // catalog is performing will ensure that no data gets written where unmarshalling + // to this type will error. + return health, fmt.Errorf("error unmarshalling health status data: %w", err) + } + + if hs.Status > health { + health = hs.Status + } + } + } + + return health, nil +} diff --git a/internal/catalog/internal/controllers/nodehealth/controller_test.go b/internal/catalog/internal/controllers/nodehealth/controller_test.go new file mode 100644 index 0000000000000..c3d9dd2d90e43 --- /dev/null +++ b/internal/catalog/internal/controllers/nodehealth/controller_test.go @@ -0,0 +1,365 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package nodehealth + +import ( + "context" + "fmt" + "testing" + + svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" + "github.com/hashicorp/consul/internal/catalog/internal/types" + "github.com/hashicorp/consul/internal/controller" + "github.com/hashicorp/consul/internal/resource/resourcetest" + pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v1alpha1" + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/hashicorp/consul/proto/private/prototest" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/hashicorp/consul/sdk/testutil/retry" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var ( + nodeData = &pbcatalog.Node{ + Addresses: []*pbcatalog.NodeAddress{ + { + Host: "127.0.0.1", + }, + }, + } + + dnsPolicyData = &pbcatalog.DNSPolicy{ + Workloads: &pbcatalog.WorkloadSelector{ + Prefixes: []string{""}, + }, + Weights: &pbcatalog.Weights{ + Passing: 1, + Warning: 1, + }, + } +) + +func resourceID(rtype *pbresource.Type, name string) *pbresource.ID { + return &pbresource.ID{ + Type: rtype, + Tenancy: &pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + }, + Name: name, + } +} + +type nodeHealthControllerTestSuite struct { + suite.Suite + + resourceClient pbresource.ResourceServiceClient + runtime controller.Runtime + + ctl nodeHealthReconciler + + nodeNoHealth *pbresource.ID + nodePassing *pbresource.ID + nodeWarning *pbresource.ID + nodeCritical *pbresource.ID + nodeMaintenance *pbresource.ID +} + +func (suite *nodeHealthControllerTestSuite) SetupTest() { + suite.resourceClient = svctest.RunResourceService(suite.T(), types.Register) + suite.runtime = controller.Runtime{Client: suite.resourceClient} + + // The rest of the setup will be to prime the resource service with some data + suite.nodeNoHealth = resourcetest.Resource(types.NodeType, "test-node-no-health"). + WithData(suite.T(), nodeData). + Write(suite.T(), suite.resourceClient).Id + + suite.nodePassing = resourcetest.Resource(types.NodeType, "test-node-passing"). + WithData(suite.T(), nodeData). + Write(suite.T(), suite.resourceClient).Id + + suite.nodeWarning = resourcetest.Resource(types.NodeType, "test-node-warning"). + WithData(suite.T(), nodeData). + Write(suite.T(), suite.resourceClient).Id + + suite.nodeCritical = resourcetest.Resource(types.NodeType, "test-node-critical"). + WithData(suite.T(), nodeData). + Write(suite.T(), suite.resourceClient).Id + + suite.nodeMaintenance = resourcetest.Resource(types.NodeType, "test-node-maintenance"). + WithData(suite.T(), nodeData). + Write(suite.T(), suite.resourceClient).Id + + nodeHealthDesiredStatus := map[string]pbcatalog.Health{ + suite.nodePassing.Name: pbcatalog.Health_HEALTH_PASSING, + suite.nodeWarning.Name: pbcatalog.Health_HEALTH_WARNING, + suite.nodeCritical.Name: pbcatalog.Health_HEALTH_CRITICAL, + suite.nodeMaintenance.Name: pbcatalog.Health_HEALTH_MAINTENANCE, + } + + // In order to exercise the behavior to ensure that its not a last-status-wins sort of thing + // we are strategically naming health statuses so that they will be returned in an order with + // the most precedent status being in the middle of the list. This will ensure that statuses + // seen later can overide a previous status and that statuses seen later do not override if + // they would lower the overall status such as going from critical -> warning. + precedenceHealth := []pbcatalog.Health{ + pbcatalog.Health_HEALTH_PASSING, + pbcatalog.Health_HEALTH_WARNING, + pbcatalog.Health_HEALTH_CRITICAL, + pbcatalog.Health_HEALTH_MAINTENANCE, + pbcatalog.Health_HEALTH_CRITICAL, + pbcatalog.Health_HEALTH_WARNING, + pbcatalog.Health_HEALTH_PASSING, + } + + for _, node := range []*pbresource.ID{suite.nodePassing, suite.nodeWarning, suite.nodeCritical, suite.nodeMaintenance} { + for idx, health := range precedenceHealth { + if nodeHealthDesiredStatus[node.Name] >= health { + resourcetest.Resource(types.HealthStatusType, fmt.Sprintf("test-check-%s-%d", node.Name, idx)). + WithData(suite.T(), &pbcatalog.HealthStatus{Type: "tcp", Status: health}). + WithOwner(node). + Write(suite.T(), suite.resourceClient) + } + } + } + + // create a DNSPolicy to be owned by the node. The type doesn't really matter it just needs + // to be something that doesn't care about its owner. All we want to prove is that we are + // filtering out non-HealthStatus types appropriately. + resourcetest.Resource(types.DNSPolicyType, "test-policy"). + WithData(suite.T(), dnsPolicyData). + WithOwner(suite.nodeNoHealth). + Write(suite.T(), suite.resourceClient) +} + +func (suite *nodeHealthControllerTestSuite) TestGetNodeHealthListError() { + // This resource id references a resource type that will not be + // registered with the resource service. The ListByOwner call + // should produce an InvalidArgument error. This test is meant + // to validate how that error is handled (its propagated back + // to the caller) + ref := resourceID( + &pbresource.Type{Group: "not", GroupVersion: "v1", Kind: "found"}, + "irrelevant", + ) + health, err := getNodeHealth(context.Background(), suite.runtime, ref) + require.Equal(suite.T(), pbcatalog.Health_HEALTH_CRITICAL, health) + require.Error(suite.T(), err) + require.Equal(suite.T(), codes.InvalidArgument, status.Code(err)) +} + +func (suite *nodeHealthControllerTestSuite) TestGetNodeHealthNoNode() { + // This test is meant to ensure that when the node doesn't exist + // no error is returned but also no data is. The default passing + // status should then be returned in the same manner as the node + // existing but with no associated HealthStatus resources. + ref := resourceID(types.NodeType, "foo") + ref.Uid = ulid.Make().String() + health, err := getNodeHealth(context.Background(), suite.runtime, ref) + + require.NoError(suite.T(), err) + require.Equal(suite.T(), pbcatalog.Health_HEALTH_PASSING, health) +} + +func (suite *nodeHealthControllerTestSuite) TestGetNodeHealthNoStatus() { + health, err := getNodeHealth(context.Background(), suite.runtime, suite.nodeNoHealth) + require.NoError(suite.T(), err) + require.Equal(suite.T(), pbcatalog.Health_HEALTH_PASSING, health) +} + +func (suite *nodeHealthControllerTestSuite) TestGetNodeHealthPassingStatus() { + health, err := getNodeHealth(context.Background(), suite.runtime, suite.nodePassing) + require.NoError(suite.T(), err) + require.Equal(suite.T(), pbcatalog.Health_HEALTH_PASSING, health) +} + +func (suite *nodeHealthControllerTestSuite) TestGetNodeHealthCriticalStatus() { + health, err := getNodeHealth(context.Background(), suite.runtime, suite.nodeCritical) + require.NoError(suite.T(), err) + require.Equal(suite.T(), pbcatalog.Health_HEALTH_CRITICAL, health) +} + +func (suite *nodeHealthControllerTestSuite) TestGetNodeHealthWarningStatus() { + health, err := getNodeHealth(context.Background(), suite.runtime, suite.nodeWarning) + require.NoError(suite.T(), err) + require.Equal(suite.T(), pbcatalog.Health_HEALTH_WARNING, health) +} + +func (suite *nodeHealthControllerTestSuite) TestGetNodeHealthMaintenanceStatus() { + health, err := getNodeHealth(context.Background(), suite.runtime, suite.nodeMaintenance) + require.NoError(suite.T(), err) + require.Equal(suite.T(), pbcatalog.Health_HEALTH_MAINTENANCE, health) +} + +func (suite *nodeHealthControllerTestSuite) TestReconcileNodeNotFound() { + // This test ensures that removed nodes are ignored. In particular we don't + // want to propagate the error and indefinitely keep re-reconciling in this case. + err := suite.ctl.Reconcile(context.Background(), suite.runtime, controller.Request{ + ID: resourceID(types.NodeType, "not-found"), + }) + require.NoError(suite.T(), err) +} + +func (suite *nodeHealthControllerTestSuite) TestReconcilePropagateReadError() { + // This test aims to ensure that errors other than NotFound errors coming + // from the initial resource read get propagated. This cases is very unrealsitic + // as the controller should not have given us a request ID for a resource type + // that doesn't exist but this was the easiest way I could think of to synthesize + // a Read error. + ref := resourceID( + &pbresource.Type{Group: "not", GroupVersion: "v1", Kind: "found"}, + "irrelevant", + ) + + err := suite.ctl.Reconcile(context.Background(), suite.runtime, controller.Request{ + ID: ref, + }) + require.Error(suite.T(), err) + require.Equal(suite.T(), codes.InvalidArgument, status.Code(err)) +} + +func (suite *nodeHealthControllerTestSuite) testReconcileStatus(id *pbresource.ID, expectedStatus *pbresource.Condition) *pbresource.Resource { + suite.T().Helper() + + err := suite.ctl.Reconcile(context.Background(), suite.runtime, controller.Request{ + ID: id, + }) + require.NoError(suite.T(), err) + + rsp, err := suite.resourceClient.Read(context.Background(), &pbresource.ReadRequest{ + Id: id, + }) + require.NoError(suite.T(), err) + + nodeHealthStatus, found := rsp.Resource.Status[StatusKey] + require.True(suite.T(), found) + require.Equal(suite.T(), rsp.Resource.Generation, nodeHealthStatus.ObservedGeneration) + require.Len(suite.T(), nodeHealthStatus.Conditions, 1) + prototest.AssertDeepEqual(suite.T(), + nodeHealthStatus.Conditions[0], + expectedStatus) + + return rsp.Resource +} + +func (suite *nodeHealthControllerTestSuite) TestReconcile_StatusPassing() { + suite.testReconcileStatus(suite.nodePassing, &pbresource.Condition{ + Type: StatusConditionHealthy, + State: pbresource.Condition_STATE_TRUE, + Reason: "HEALTH_PASSING", + Message: NodeHealthyMessage, + }) +} + +func (suite *nodeHealthControllerTestSuite) TestReconcile_StatusWarning() { + suite.testReconcileStatus(suite.nodeWarning, &pbresource.Condition{ + Type: StatusConditionHealthy, + State: pbresource.Condition_STATE_FALSE, + Reason: "HEALTH_WARNING", + Message: NodeUnhealthyMessage, + }) +} + +func (suite *nodeHealthControllerTestSuite) TestReconcile_StatusCritical() { + suite.testReconcileStatus(suite.nodeCritical, &pbresource.Condition{ + Type: StatusConditionHealthy, + State: pbresource.Condition_STATE_FALSE, + Reason: "HEALTH_CRITICAL", + Message: NodeUnhealthyMessage, + }) +} + +func (suite *nodeHealthControllerTestSuite) TestReconcile_StatusMaintenance() { + suite.testReconcileStatus(suite.nodeMaintenance, &pbresource.Condition{ + Type: StatusConditionHealthy, + State: pbresource.Condition_STATE_FALSE, + Reason: "HEALTH_MAINTENANCE", + Message: NodeUnhealthyMessage, + }) +} + +func (suite *nodeHealthControllerTestSuite) TestReconcile_AvoidRereconciliationWrite() { + res1 := suite.testReconcileStatus(suite.nodeWarning, &pbresource.Condition{ + Type: StatusConditionHealthy, + State: pbresource.Condition_STATE_FALSE, + Reason: "HEALTH_WARNING", + Message: NodeUnhealthyMessage, + }) + + res2 := suite.testReconcileStatus(suite.nodeWarning, &pbresource.Condition{ + Type: StatusConditionHealthy, + State: pbresource.Condition_STATE_FALSE, + Reason: "HEALTH_WARNING", + Message: NodeUnhealthyMessage, + }) + + // If another status write was performed then the versions would differ. This + // therefore proves that after a second reconciliation without any change in status + // that we are not making subsequent status writes. + require.Equal(suite.T(), res1.Version, res2.Version) +} + +func (suite *nodeHealthControllerTestSuite) waitForReconciliation(id *pbresource.ID, reason string) { + suite.T().Helper() + + retry.Run(suite.T(), func(r *retry.R) { + rsp, err := suite.resourceClient.Read(context.Background(), &pbresource.ReadRequest{ + Id: id, + }) + require.NoError(r, err) + + nodeHealthStatus, found := rsp.Resource.Status[StatusKey] + require.True(r, found) + require.Equal(r, rsp.Resource.Generation, nodeHealthStatus.ObservedGeneration) + require.Len(r, nodeHealthStatus.Conditions, 1) + require.Equal(r, reason, nodeHealthStatus.Conditions[0].Reason) + }) +} +func (suite *nodeHealthControllerTestSuite) TestController() { + // create the controller manager + mgr := controller.NewManager(suite.resourceClient, testutil.Logger(suite.T())) + + // register our controller + mgr.Register(NodeHealthController()) + mgr.SetRaftLeader(true) + ctx, cancel := context.WithCancel(context.Background()) + suite.T().Cleanup(cancel) + + // run the manager + go mgr.Run(ctx) + + // ensure that the node health eventually gets set. + suite.waitForReconciliation(suite.nodePassing, "HEALTH_PASSING") + + // rewrite the resource - this will cause the nodes health + // to be rereconciled but wont result in any health change + resourcetest.Resource(types.NodeType, suite.nodePassing.Name). + WithData(suite.T(), &pbcatalog.Node{ + Addresses: []*pbcatalog.NodeAddress{ + { + Host: "198.18.0.1", + }, + }, + }). + Write(suite.T(), suite.resourceClient) + + // wait for rereconciliation to happen + suite.waitForReconciliation(suite.nodePassing, "HEALTH_PASSING") + + resourcetest.Resource(types.HealthStatusType, "failure"). + WithData(suite.T(), &pbcatalog.HealthStatus{Type: "fake", Status: pbcatalog.Health_HEALTH_CRITICAL}). + WithOwner(suite.nodePassing). + Write(suite.T(), suite.resourceClient) + + suite.waitForReconciliation(suite.nodePassing, "HEALTH_CRITICAL") +} + +func TestNodeHealthController(t *testing.T) { + suite.Run(t, new(nodeHealthControllerTestSuite)) +} diff --git a/internal/catalog/internal/controllers/nodehealth/status.go b/internal/catalog/internal/controllers/nodehealth/status.go new file mode 100644 index 0000000000000..8cb989101dad6 --- /dev/null +++ b/internal/catalog/internal/controllers/nodehealth/status.go @@ -0,0 +1,12 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package nodehealth + +const ( + StatusKey = "consul.io/node-health" + StatusConditionHealthy = "healthy" + + NodeHealthyMessage = "All node health checks are passing" + NodeUnhealthyMessage = "One or more node health checks are not passing" +) diff --git a/internal/catalog/internal/controllers/register.go b/internal/catalog/internal/controllers/register.go new file mode 100644 index 0000000000000..b4b6f190f74b8 --- /dev/null +++ b/internal/catalog/internal/controllers/register.go @@ -0,0 +1,13 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package controllers + +import ( + "github.com/hashicorp/consul/internal/catalog/internal/controllers/nodehealth" + "github.com/hashicorp/consul/internal/controller" +) + +func Register(mgr *controller.Manager) { + mgr.Register(nodehealth.NodeHealthController()) +} diff --git a/internal/controller/api.go b/internal/controller/api.go index 7a2e89be46412..c2cc0925da1b4 100644 --- a/internal/controller/api.go +++ b/internal/controller/api.go @@ -169,6 +169,29 @@ func MapOwner(_ context.Context, _ Runtime, res *pbresource.Resource) ([]Request return reqs, nil } +func MapOwnerFiltered(filter *pbresource.Type) DependencyMapper { + return func(_ context.Context, _ Runtime, res *pbresource.Resource) ([]Request, error) { + if res.Owner == nil { + return nil, nil + } + + ownerType := res.Owner.GetType() + if ownerType.Group != filter.Group { + return nil, nil + } + + if ownerType.GroupVersion != filter.GroupVersion { + return nil, nil + } + + if ownerType.Kind != filter.Kind { + return nil, nil + } + + return []Request{{ID: res.Owner}}, nil + } +} + // Placement determines where and how many replicas of the controller will run. type Placement int diff --git a/internal/controller/api_test.go b/internal/controller/api_test.go index 2006664b20fb7..51bb62843f7df 100644 --- a/internal/controller/api_test.go +++ b/internal/controller/api_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/require" svctest "github.com/hashicorp/consul/agent/grpc-external/services/resource/testing" @@ -266,3 +267,76 @@ func testContext(t *testing.T) context.Context { return ctx } + +func resourceID(group string, version string, kind string, name string) *pbresource.ID { + return &pbresource.ID{ + Type: &pbresource.Type{ + Group: group, + GroupVersion: version, + Kind: kind, + }, + Tenancy: &pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + }, + Name: name, + } +} + +func TestMapOwnerFiltered(t *testing.T) { + mapper := controller.MapOwnerFiltered(&pbresource.Type{ + Group: "foo", + GroupVersion: "v1", + Kind: "bar", + }) + + type testCase struct { + owner *pbresource.ID + matches bool + } + + cases := map[string]testCase{ + "nil-owner": { + owner: nil, + matches: false, + }, + "group-mismatch": { + owner: resourceID("other", "v1", "bar", "irrelevant"), + matches: false, + }, + "group-version-mismatch": { + owner: resourceID("foo", "v2", "bar", "irrelevant"), + matches: false, + }, + "kind-mismatch": { + owner: resourceID("foo", "v1", "baz", "irrelevant"), + matches: false, + }, + "match": { + owner: resourceID("foo", "v1", "bar", "irrelevant"), + matches: true, + }, + } + + for name, tcase := range cases { + t.Run(name, func(t *testing.T) { + // the runtime is not used by the mapper so its fine to pass an empty struct + req, err := mapper(context.Background(), controller.Runtime{}, &pbresource.Resource{ + Id: resourceID("foo", "v1", "other", "x"), + Owner: tcase.owner, + }) + + // The mapper has no error paths at present + require.NoError(t, err) + + if tcase.matches { + require.NotNil(t, req) + require.Len(t, req, 1) + prototest.AssertDeepEqual(t, req[0].ID, tcase.owner, cmpopts.EquateEmpty()) + } else { + require.Nil(t, req) + } + }) + } +} diff --git a/internal/resource/resourcetest/builder.go b/internal/resource/resourcetest/builder.go new file mode 100644 index 0000000000000..46039e209a70f --- /dev/null +++ b/internal/resource/resourcetest/builder.go @@ -0,0 +1,136 @@ +package resourcetest + +import ( + "context" + "testing" + + "github.com/hashicorp/consul/proto-public/pbresource" + "github.com/oklog/ulid/v2" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" +) + +type resourceBuilder struct { + resource *pbresource.Resource + statuses map[string]*pbresource.Status + dontCleanup bool +} + +func Resource(rtype *pbresource.Type, name string) *resourceBuilder { + return &resourceBuilder{ + resource: &pbresource.Resource{ + Id: &pbresource.ID{ + Type: &pbresource.Type{ + Group: rtype.Group, + GroupVersion: rtype.GroupVersion, + Kind: rtype.Kind, + }, + Tenancy: &pbresource.Tenancy{ + Partition: "default", + Namespace: "default", + PeerName: "local", + }, + Name: name, + }, + }, + } +} + +func (b *resourceBuilder) WithData(t *testing.T, data protoreflect.ProtoMessage) *resourceBuilder { + anyData, err := anypb.New(data) + require.NoError(t, err) + b.resource.Data = anyData + return b +} + +func (b *resourceBuilder) WithOwner(id *pbresource.ID) *resourceBuilder { + b.resource.Owner = id + return b +} + +func (b *resourceBuilder) WithStatus(key string, status *pbresource.Status) *resourceBuilder { + if b.statuses == nil { + b.statuses = make(map[string]*pbresource.Status) + } + b.statuses[key] = status + return b +} + +func (b *resourceBuilder) WithoutCleanup() *resourceBuilder { + b.dontCleanup = true + return b +} + +func (b *resourceBuilder) WithGeneration(gen string) *resourceBuilder { + b.resource.Generation = gen + return b +} + +func (b *resourceBuilder) Build() *pbresource.Resource { + // clone the resource so we can add on status information + res := proto.Clone(b.resource).(*pbresource.Resource) + + // fill in the generation if empty to make it look like + // a real managed resource + if res.Generation == "" { + res.Generation = ulid.Make().String() + } + + // Now create the status map + res.Status = make(map[string]*pbresource.Status) + for key, original := range b.statuses { + status := &pbresource.Status{ + ObservedGeneration: res.Generation, + Conditions: original.Conditions, + } + res.Status[key] = status + } + + return res +} + +func (b *resourceBuilder) Write(t *testing.T, client pbresource.ResourceServiceClient) *pbresource.Resource { + res := b.resource + + rsp, err := client.Write(context.Background(), &pbresource.WriteRequest{ + Resource: res, + }) + + require.NoError(t, err) + + if !b.dontCleanup { + t.Cleanup(func() { + _, err := client.Delete(context.Background(), &pbresource.DeleteRequest{ + Id: rsp.Resource.Id, + }) + require.NoError(t, err) + }) + } + + if len(b.statuses) == 0 { + return rsp.Resource + } + + for key, original := range b.statuses { + status := &pbresource.Status{ + ObservedGeneration: rsp.Resource.Generation, + Conditions: original.Conditions, + } + _, err := client.WriteStatus(context.Background(), &pbresource.WriteStatusRequest{ + Id: rsp.Resource.Id, + Key: key, + Status: status, + }) + require.NoError(t, err) + } + + readResp, err := client.Read(context.Background(), &pbresource.ReadRequest{ + Id: rsp.Resource.Id, + }) + + require.NoError(t, err) + + return readResp.Resource +}