diff --git a/cmd/aws-ebs-csi-driver-operator/main.go b/cmd/aws-ebs-csi-driver-operator/main.go index 456c82576..5ac7b8334 100644 --- a/cmd/aws-ebs-csi-driver-operator/main.go +++ b/cmd/aws-ebs-csi-driver-operator/main.go @@ -2,15 +2,19 @@ package main import ( "context" + "fmt" "os" "github.com/openshift/csi-operator/pkg/driver/aws-ebs" "github.com/openshift/library-go/pkg/controller/controllercmd" "github.com/spf13/cobra" "k8s.io/component-base/cli" + "k8s.io/klog/v2" + configclient "github.com/openshift/client-go/config/clientset/versioned" "github.com/openshift/csi-operator/pkg/operator" "github.com/openshift/csi-operator/pkg/version" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" ) func main() { @@ -49,6 +53,31 @@ func NewOperatorCommand() *cobra.Command { } func runCSIDriverOperator(ctx context.Context, controllerConfig *controllercmd.ControllerContext) error { + klog.Info("Starting AWS EBS CSI Driver Operator") + opConfig := aws_ebs.GetAWSEBSOperatorConfig() + + configClient, err := configclient.NewForConfig(controllerConfig.KubeConfig) + if err != nil { + klog.Errorf("Failed to create config client: %v", err) + return fmt.Errorf("failed to create config client: %v", err) + } + + coreClient, err := corev1.NewForConfig(controllerConfig.KubeConfig) + if err != nil { + klog.Errorf("Failed to create core client: %v", err) + return fmt.Errorf("failed to create core client: %v", err) + } + + ebsTagsController, err := aws_ebs.NewEBSVolumeTagController(configClient, coreClient) + if err != nil { + klog.Errorf("Failed to create EBS volume tag controller: %v", err) + return fmt.Errorf("failed to create EBS volume tag controller: %v", err) + } + + go ebsTagsController.Run(ctx) + + klog.Info("EBS Volume Tag Controller is running") + return operator.RunOperator(ctx, controllerConfig, *guestKubeconfig, opConfig) } diff --git a/pkg/driver/aws-ebs/aws_ebs_tags_controller.go b/pkg/driver/aws-ebs/aws_ebs_tags_controller.go new file mode 100644 index 000000000..6d111dd92 --- /dev/null +++ b/pkg/driver/aws-ebs/aws_ebs_tags_controller.go @@ -0,0 +1,268 @@ +package aws_ebs + +import ( + "context" + "fmt" + "reflect" + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + corev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" + "k8s.io/klog/v2" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + + configv1 "github.com/openshift/api/config/v1" + configclient "github.com/openshift/client-go/config/clientset/versioned" +) + +const ( + awsSecretNamespace = "openshift-cluster-csi-drivers" + awsSecretName = "ebs-cloud-credentials" + infrastructureResource = "cluster" + driverName = "ebs.csi.aws.com" +) + +// EBSVolumeTagController is the custom controller +type EBSVolumeTagController struct { + configClient configclient.Interface + coreClient corev1.CoreV1Interface + queue workqueue.RateLimitingInterface + informer cache.SharedIndexInformer + awsEC2Client *ec2.EC2 +} + +// NewEBSVolumeTagController initializes the controller and sets up the AWS session using credentials from a Kubernetes secret +func NewEBSVolumeTagController(configClient configclient.Interface, coreClient corev1.CoreV1Interface) (*EBSVolumeTagController, error) { + queue := workqueue.NewRateLimitingQueue(workqueue.DefaultControllerRateLimiter()) + + awsRegion, err := getAWSRegionFromInfrastructure(configClient) + if err != nil { + return nil, fmt.Errorf("error retrieving AWS region from infrastructure: %v", err) + } + + // Initialize AWS EC2 client using the credentials from the secret + awsEC2Client, err := getEC2Client(context.TODO(), coreClient, awsRegion) + if err != nil { + return nil, fmt.Errorf("error creating AWS EC2 client: %v", err) + } + + // Create a listerWatcher for the Infrastructure resource + listerWatcher := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + return configClient.ConfigV1().Infrastructures().List(context.TODO(), options) + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + options.FieldSelector = fields.OneTermEqualSelector("metadata.name", infrastructureResource).String() + return configClient.ConfigV1().Infrastructures().Watch(context.TODO(), options) + }, + } + + // Set up a shared informer + informer := cache.NewSharedIndexInformer( + listerWatcher, + &configv1.Infrastructure{}, + time.Minute*10, + cache.Indexers{}, + ) + + controller := &EBSVolumeTagController{ + configClient: configClient, + coreClient: coreClient, + queue: queue, + informer: informer, + awsEC2Client: awsEC2Client, + } + + // Add event handlers to the informer + _, err = informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + controller.handleAdd(obj) + }, + UpdateFunc: func(oldObj, newObj interface{}) { + controller.handleUpdate(oldObj, newObj) + }, + DeleteFunc: func(obj interface{}) { + controller.handleDelete(obj) + }, + }) + if err != nil { + return nil, err + } + + return controller, nil +} + +// getEC2Client retrieves AWS credentials from the secret and creates an AWS EC2 client +func getEC2Client(ctx context.Context, coreClient corev1.CoreV1Interface, awsRegion string) (*ec2.EC2, error) { + // Fetch the secret containing AWS credentials + secret, err := coreClient.Secrets(awsSecretNamespace).Get(ctx, awsSecretName, metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("error retrieving AWS credentials secret: %v", err) + } + + awsAccessKeyID := secret.Data["aws_access_key_id"] + awsSecretAccessKey := secret.Data["aws_secret_access_key"] + + // Create a new AWS session using the credentials + awsSession, err := session.NewSession(&aws.Config{ + Region: aws.String(awsRegion), + Credentials: credentials.NewStaticCredentials(string(awsAccessKeyID), string(awsSecretAccessKey), ""), + }) + if err != nil { + return nil, fmt.Errorf("error creating AWS session: %v", err) + } + + // Return an EC2 client + return ec2.New(awsSession), nil +} + +// getAWSRegionFromInfrastructure retrieves the AWS region from the Infrastructure resource in OpenShift +func getAWSRegionFromInfrastructure(configClient configclient.Interface) (string, error) { + infra, err := configClient.ConfigV1().Infrastructures().Get(context.TODO(), infrastructureResource, metav1.GetOptions{}) + if err != nil { + return "", fmt.Errorf("failed to retrieve Infrastructure resource: %v", err) + } + + if infra.Status.PlatformStatus == nil || infra.Status.PlatformStatus.AWS == nil { + return "", fmt.Errorf("AWS platform status not found in Infrastructure resource") + } + + return infra.Status.PlatformStatus.AWS.Region, nil +} + +// handleAdd is called when an Infrastructure resource is added +func (c *EBSVolumeTagController) handleAdd(obj interface{}) { + infra := obj.(*configv1.Infrastructure) + klog.Infof("Infrastructure resource added: %s", infra.Name) + c.processInfrastructure(infra) +} + +// handleUpdate is called when an Infrastructure resource is updated +func (c *EBSVolumeTagController) handleUpdate(oldObj, newObj interface{}) { + oldInfra := oldObj.(*configv1.Infrastructure) + newInfra := newObj.(*configv1.Infrastructure) + + klog.Infof("Infrastructure resource updated: %s", newInfra.Name) + + if !reflect.DeepEqual(oldInfra.Status.PlatformStatus.AWS.ResourceTags, newInfra.Status.PlatformStatus.AWS.ResourceTags) { + klog.Infof("AWS ResourceTags changed: triggering processing") + c.processInfrastructure(newInfra) + } +} + +// handleDelete is called when an Infrastructure resource is deleted +func (c *EBSVolumeTagController) handleDelete(obj interface{}) { + infra := obj.(*configv1.Infrastructure) + klog.Infof("Infrastructure resource deleted: %s", infra.Name) +} + +// processInfrastructure processes the Infrastructure resource and updates EBS tags +func (c *EBSVolumeTagController) processInfrastructure(infra *configv1.Infrastructure) { + if infra.Status.PlatformStatus != nil && infra.Status.PlatformStatus.AWS != nil { + awsInfra := infra.Status.PlatformStatus.AWS + err := c.fetchPVsAndUpdateTags(awsInfra.ResourceTags) + if err != nil { + klog.Errorf("Error processing PVs for infrastructure update: %v", err) + } + } +} + +// fetchPVsAndUpdateTags retrieves all PVs and updates the AWS EBS tags +func (c *EBSVolumeTagController) fetchPVsAndUpdateTags(resourceTags []configv1.AWSResourceTag) error { + pvs, err := c.coreClient.PersistentVolumes().List(context.TODO(), metav1.ListOptions{}) + if err != nil { + return fmt.Errorf("error fetching PVs: %v", err) + } + + for _, pv := range pvs.Items { + if pv.Spec.CSI != nil && pv.Spec.CSI.Driver == driverName { + volumeID := pv.Spec.CSI.VolumeHandle + err = c.updateEBSTags(volumeID, resourceTags) + if err != nil { + klog.Errorf("Error updating tags for volume %s: %v", volumeID, err) + } else { + klog.Infof("Successfully updated tags for volume %s", volumeID) + } + } + } + + return nil +} + +// updateEBSTags updates the tags of an AWS EBS volume +func (c *EBSVolumeTagController) updateEBSTags(volumeID string, resourceTags []configv1.AWSResourceTag) error { + existingTagsOutput, err := c.awsEC2Client.DescribeTags(&ec2.DescribeTagsInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("resource-id"), + Values: []*string{aws.String(volumeID)}, + }, + }, + }) + if err != nil { + return err + } + + mergedTags := mergeTags(existingTagsOutput.Tags, resourceTags) + + klog.Infof("Updating EBS tags for volume ID %s with tags: %v", volumeID, mergedTags) + + _, err = c.awsEC2Client.CreateTags(&ec2.CreateTagsInput{ + Resources: []*string{aws.String(volumeID)}, + Tags: mergedTags, + }) + + return err +} + +// mergeTags merges existing AWS tags with new resource tags from OpenShift infrastructure +func mergeTags(existingTags []*ec2.TagDescription, resourceTags []configv1.AWSResourceTag) []*ec2.Tag { + tagMap := make(map[string]string) + + // Add existing tags to the map + for _, tagDesc := range existingTags { + tagMap[*tagDesc.Key] = *tagDesc.Value + } + + // Override with new resource tags + for _, tag := range resourceTags { + tagMap[tag.Key] = tag.Value + } + + // Convert map back to slice of ec2.Tag + var mergedTags []*ec2.Tag + for key, value := range tagMap { + mergedTags = append(mergedTags, &ec2.Tag{ + Key: aws.String(key), + Value: aws.String(value), + }) + } + + return mergedTags +} + +// Run starts the controller and processes events from the informer +func (c *EBSVolumeTagController) Run(ctx context.Context) { + defer c.queue.ShutDown() + + klog.Infof("Starting EBSVolumeTagController") + go c.informer.Run(ctx.Done()) + + if !cache.WaitForCacheSync(ctx.Done(), c.informer.HasSynced) { + klog.Fatal("Failed to sync caches") + return + } + + <-ctx.Done() + + klog.Infof("Shutting down EBSVolumeTagController") +} diff --git a/pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go b/pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go new file mode 100644 index 000000000..b56203737 --- /dev/null +++ b/pkg/driver/aws-ebs/aws_ebs_tags_controller_test.go @@ -0,0 +1,146 @@ +package aws_ebs + +import ( + "context" + fakeconfig "github.com/openshift/client-go/config/clientset/versioned/fake" + "reflect" + "sort" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + configv1 "github.com/openshift/api/config/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +func TestGetAWSRegionFromInfrastructure(t *testing.T) { + infra := &configv1.Infrastructure{ + ObjectMeta: metav1.ObjectMeta{ + Name: "cluster", + }, + Status: configv1.InfrastructureStatus{ + PlatformStatus: &configv1.PlatformStatus{ + AWS: &configv1.AWSPlatformStatus{ + Region: "us-east-1", + }, + }, + }, + } + + // Use the OpenShift fake clientset, not Kubernetes fake client + fakeConfigClient := fakeconfig.NewSimpleClientset() + + // Add the infrastructure resource to the fake OpenShift client + _, err := fakeConfigClient.ConfigV1().Infrastructures().Create(context.TODO(), infra, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("failed to create infrastructure resource: %v", err) + } + + // Call the function to test if the region is correctly retrieved + region, err := getAWSRegionFromInfrastructure(fakeConfigClient) + if err != nil { + t.Fatalf("unexpected error retrieving AWS region: %v", err) + } + + expectedRegion := "us-east-1" + if region != expectedRegion { + t.Errorf("expected AWS region %s, got %s", expectedRegion, region) + } +} + +// TestGetEC2Client tests if the EC2 client is correctly initialized using credentials from a Kubernetes secret. +func TestGetEC2Client(t *testing.T) { + // Create a fake Kubernetes core client with a secret containing AWS credentials + fakeCoreClient := fake.NewSimpleClientset().CoreV1() + + // Add a secret containing AWS credentials + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: awsSecretName, + Namespace: awsSecretNamespace, + }, + Data: map[string][]byte{ + "aws_access_key_id": []byte("fake-access-key-id"), + "aws_secret_access_key": []byte("fake-secret-access-key"), + }, + } + _, err := fakeCoreClient.Secrets(awsSecretNamespace).Create(context.TODO(), secret, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("failed to create secret: %v", err) + } + + // Call getEC2Client and verify that the AWS session is created + awsRegion := "us-east-1" + client, err := getEC2Client(context.TODO(), fakeCoreClient, awsRegion) + if err != nil { + t.Fatalf("unexpected error creating EC2 client: %v", err) + } + + // Check if the EC2 client is not nil (indicating successful creation) + if client == nil { + t.Errorf("expected non-nil EC2 client, got nil") + } +} + +// TestUpdateEBSTags tests the logic of merging and updating AWS tags without making real AWS calls. +func TestUpdateEBSTags(t *testing.T) { + // Define existing tags returned by AWS EC2 + existingTagsOutput := &ec2.DescribeTagsOutput{ + Tags: []*ec2.TagDescription{ + { + Key: aws.String("existing-key"), + Value: aws.String("existing-value"), + }, + { + Key: aws.String("unchanged-key"), + Value: aws.String("unchanged-value"), + }, + }, + } + + // Define new resource tags from the infrastructure resource + resourceTags := []configv1.AWSResourceTag{ + { + Key: "new-key", + Value: "new-value", + }, + { + Key: "existing-key", // This will override the existing tag + Value: "updated-value", + }, + } + + // Call the internal function that merges and updates the tags + mergedTags := mergeTags(existingTagsOutput.Tags, resourceTags) + + // Define the expected merged tags + expectedTags := []*ec2.Tag{ + { + Key: aws.String("existing-key"), + Value: aws.String("updated-value"), // Should be updated + }, + { + Key: aws.String("unchanged-key"), + Value: aws.String("unchanged-value"), + }, + { + Key: aws.String("new-key"), + Value: aws.String("new-value"), + }, + } + sortTags(expectedTags) + sortTags(mergedTags) + // Compare the merged tags with the expected tags + if !reflect.DeepEqual(mergedTags, expectedTags) { + t.Errorf("expected tags %v, got %v", expectedTags, mergedTags) + } +} + +// Helper function to sort tags by Key +func sortTags(tags []*ec2.Tag) { + sort.Slice(tags, func(i, j int) bool { + return *tags[i].Key < *tags[j].Key + }) +}