From f325b37e2bbb4ba185b914b4aa4db24ac3013fc2 Mon Sep 17 00:00:00 2001 From: Zhiqiang ZHOU Date: Sun, 15 Sep 2024 12:44:51 -0700 Subject: [PATCH] fix: update controlled cloudflared (#121) --- .../main.go | 9 +- pkg/cloudflare-controller/tunnel-client.go | 8 ++ .../controlled-cloudflared-connector.go | 44 +++++- .../controlled_cloudflared_connector_test.go | 127 ++++++++++++++++++ 4 files changed, 182 insertions(+), 6 deletions(-) create mode 100644 test/integration/controller/controlled_cloudflared_connector_test.go diff --git a/cmd/cloudflare-tunnel-ingress-controller/main.go b/cmd/cloudflare-tunnel-ingress-controller/main.go index 750bdb1..0a1daeb 100644 --- a/cmd/cloudflare-tunnel-ingress-controller/main.go +++ b/cmd/cloudflare-tunnel-ingress-controller/main.go @@ -2,18 +2,19 @@ package main import ( "context" + "log" + "os" + "time" + cloudflarecontroller "github.com/STRRL/cloudflare-tunnel-ingress-controller/pkg/cloudflare-controller" "github.com/STRRL/cloudflare-tunnel-ingress-controller/pkg/controller" "github.com/cloudflare/cloudflare-go" "github.com/go-logr/logr" "github.com/go-logr/stdr" "github.com/spf13/cobra" - "log" - "os" "sigs.k8s.io/controller-runtime/pkg/client/config" crlog "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/manager" - "time" ) type rootCmdFlags struct { @@ -99,7 +100,7 @@ func main() { case <-done: return case _ = <-ticker.C: - err := controller.CreateControlledCloudflaredIfNotExist(ctx, mgr.GetClient(), tunnelClient, options.namespace) + err := controller.CreateOrUpdateControlledCloudflared(ctx, mgr.GetClient(), tunnelClient, options.namespace) if err != nil { logger.WithName("controlled-cloudflared").Error(err, "create controlled cloudflared") } diff --git a/pkg/cloudflare-controller/tunnel-client.go b/pkg/cloudflare-controller/tunnel-client.go index 6f7e395..f9b9fb4 100644 --- a/pkg/cloudflare-controller/tunnel-client.go +++ b/pkg/cloudflare-controller/tunnel-client.go @@ -9,6 +9,14 @@ import ( "github.com/pkg/errors" ) +type TunnelClientInterface interface { + PutExposures(ctx context.Context, exposures []exposure.Exposure) error + TunnelDomain() string + FetchTunnelToken(ctx context.Context) (string, error) +} + +var _ TunnelClientInterface = &TunnelClient{} + type TunnelClient struct { logger logr.Logger cfClient *cloudflare.API diff --git a/pkg/controller/controlled-cloudflared-connector.go b/pkg/controller/controlled-cloudflared-connector.go index 4df0636..14288a2 100644 --- a/pkg/controller/controlled-cloudflared-connector.go +++ b/pkg/controller/controlled-cloudflared-connector.go @@ -12,14 +12,16 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" ) -func CreateControlledCloudflaredIfNotExist( +func CreateOrUpdateControlledCloudflared( ctx context.Context, kubeClient client.Client, - tunnelClient *cloudflarecontroller.TunnelClient, + tunnelClient cloudflarecontroller.TunnelClientInterface, namespace string, ) error { + logger := log.FromContext(ctx) list := appsv1.DeploymentList{} err := kubeClient.List(ctx, &list, &client.ListOptions{ Namespace: namespace, @@ -32,6 +34,43 @@ func CreateControlledCloudflaredIfNotExist( } if len(list.Items) > 0 { + // Check if the existing deployment needs to be updated + existingDeployment := &list.Items[0] + desiredReplicas, err := strconv.ParseInt(os.Getenv("CLOUDFLARED_REPLICA_COUNT"), 10, 32) + if err != nil { + return errors.Wrap(err, "invalid replica count") + } + + needsUpdate := false + if *existingDeployment.Spec.Replicas != int32(desiredReplicas) { + needsUpdate = true + } + + if len(existingDeployment.Spec.Template.Spec.Containers) > 0 { + container := &existingDeployment.Spec.Template.Spec.Containers[0] + if container.Image != os.Getenv("CLOUDFLARED_IMAGE") { + needsUpdate = true + } + if string(container.ImagePullPolicy) != os.Getenv("CLOUDFLARED_IMAGE_PULL_POLICY") { + needsUpdate = true + } + } + + if needsUpdate { + token, err := tunnelClient.FetchTunnelToken(ctx) + if err != nil { + return errors.Wrap(err, "fetch tunnel token") + } + + updatedDeployment := cloudflaredConnectDeploymentTemplating(token, namespace, int32(desiredReplicas)) + existingDeployment.Spec = updatedDeployment.Spec + err = kubeClient.Update(ctx, existingDeployment) + if err != nil { + return errors.Wrap(err, "update controlled-cloudflared-connector deployment") + } + logger.Info("Updated controlled-cloudflared-connector deployment", "namespace", namespace) + } + return nil } @@ -50,6 +89,7 @@ func CreateControlledCloudflaredIfNotExist( if err != nil { return errors.Wrap(err, "create controlled-cloudflared-connector deployment") } + logger.Info("Created controlled-cloudflared-connector deployment", "namespace", namespace) return nil } diff --git a/test/integration/controller/controlled_cloudflared_connector_test.go b/test/integration/controller/controlled_cloudflared_connector_test.go new file mode 100644 index 0000000..f90fd4f --- /dev/null +++ b/test/integration/controller/controlled_cloudflared_connector_test.go @@ -0,0 +1,127 @@ +package controller + +import ( + "context" + "os" + + cloudflarecontroller "github.com/STRRL/cloudflare-tunnel-ingress-controller/pkg/cloudflare-controller" + "github.com/STRRL/cloudflare-tunnel-ingress-controller/pkg/controller" + "github.com/STRRL/cloudflare-tunnel-ingress-controller/pkg/exposure" + "github.com/STRRL/cloudflare-tunnel-ingress-controller/test/fixtures" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + appsv1 "k8s.io/api/apps/v1" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" +) + +var _ cloudflarecontroller.TunnelClientInterface = &MockTunnelClient{} + +type MockTunnelClient struct { + FetchTunnelTokenFunc func(ctx context.Context) (string, error) +} + +func (m *MockTunnelClient) PutExposures(ctx context.Context, exposures []exposure.Exposure) error { + return nil +} + +func (m *MockTunnelClient) TunnelDomain() string { + return "mock.tunnel.com" +} + +func (m *MockTunnelClient) FetchTunnelToken(ctx context.Context) (string, error) { + return m.FetchTunnelTokenFunc(ctx) +} + +var _ = Describe("CreateOrUpdateControlledCloudflared", func() { + const testNamespace = "cloudflared-test" + + BeforeEach(func() { + // Set required environment variables + os.Setenv("CLOUDFLARED_REPLICA_COUNT", "2") + os.Setenv("CLOUDFLARED_IMAGE", "cloudflare/cloudflared:latest") + os.Setenv("CLOUDFLARED_IMAGE_PULL_POLICY", "IfNotPresent") + }) + + AfterEach(func() { + // Clean up environment variables + os.Unsetenv("CLOUDFLARED_REPLICA_COUNT") + os.Unsetenv("CLOUDFLARED_IMAGE") + os.Unsetenv("CLOUDFLARED_IMAGE_PULL_POLICY") + }) + + It("should create a new cloudflared deployment", func() { + // Prepare + namespaceFixtures := fixtures.NewKubernetesNamespaceFixtures(testNamespace, kubeClient) + ns, err := namespaceFixtures.Start(ctx) + Expect(err).NotTo(HaveOccurred()) + + defer func() { + err := namespaceFixtures.Stop(ctx) + Expect(err).NotTo(HaveOccurred()) + }() + + mockTunnelClient := &MockTunnelClient{ + FetchTunnelTokenFunc: func(ctx context.Context) (string, error) { + return "mock-token", nil + }, + } + + // Act + err = controller.CreateOrUpdateControlledCloudflared(ctx, kubeClient, mockTunnelClient, ns) + Expect(err).NotTo(HaveOccurred()) + + // Assert + deployment := &appsv1.Deployment{} + err = kubeClient.Get(ctx, types.NamespacedName{ + Namespace: ns, + Name: "controlled-cloudflared-connector", + }, deployment) + Expect(err).NotTo(HaveOccurred()) + + Expect(*deployment.Spec.Replicas).To(Equal(int32(2))) + Expect(deployment.Spec.Template.Spec.Containers[0].Image).To(Equal("cloudflare/cloudflared:latest")) + Expect(deployment.Spec.Template.Spec.Containers[0].ImagePullPolicy).To(Equal(v1.PullPolicy("IfNotPresent"))) + }) + + It("should update an existing cloudflared deployment", func() { + // Prepare + namespaceFixtures := fixtures.NewKubernetesNamespaceFixtures(testNamespace, kubeClient) + ns, err := namespaceFixtures.Start(ctx) + Expect(err).NotTo(HaveOccurred()) + + defer func() { + err := namespaceFixtures.Stop(ctx) + Expect(err).NotTo(HaveOccurred()) + }() + + mockTunnelClient := &MockTunnelClient{ + FetchTunnelTokenFunc: func(ctx context.Context) (string, error) { + return "mock-token", nil + }, + } + + // Create initial deployment + err = controller.CreateOrUpdateControlledCloudflared(ctx, kubeClient, mockTunnelClient, ns) + Expect(err).NotTo(HaveOccurred()) + + // Change environment variables + os.Setenv("CLOUDFLARED_REPLICA_COUNT", "3") + os.Setenv("CLOUDFLARED_IMAGE", "cloudflare/cloudflared:2022.3.0") + + // Act + err = controller.CreateOrUpdateControlledCloudflared(ctx, kubeClient, mockTunnelClient, ns) + Expect(err).NotTo(HaveOccurred()) + + // Assert + deployment := &appsv1.Deployment{} + err = kubeClient.Get(ctx, types.NamespacedName{ + Namespace: ns, + Name: "controlled-cloudflared-connector", + }, deployment) + Expect(err).NotTo(HaveOccurred()) + + Expect(*deployment.Spec.Replicas).To(Equal(int32(3))) + Expect(deployment.Spec.Template.Spec.Containers[0].Image).To(Equal("cloudflare/cloudflared:2022.3.0")) + }) +})