diff --git a/README.md b/README.md index 3b85de5c..1f63695e 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,75 @@ When using the ASG Lifecycle Hooks, ASG first sends the lifecycle action notific #### Queue Processor with Instance State Change Events When using the EC2 Console or EC2 API to terminate the instance, a state-change notification is sent and the instance termination is started. EC2 does not wait for a "continue" signal before beginning to terminate the instance. When you terminate an EC2 instance, it should trigger a graceful operating system shutdown which will send a SIGTERM to the kubelet, which will in-turn start shutting down pods by propagating that SIGTERM to the containers on the node. If the containers do not shut down by the kubelet's `podTerminationGracePeriod (k8s default is 30s)`, then it will send a SIGKILL to forcefully terminate the containers. Setting the `podTerminationGracePeriod` to a max of 90sec (probably a bit less than that) will delay the termination of pods, which helps in graceful shutdown. +#### Issuing Lifecycle Heartbeats + +You can set NTH to send heartbeats to ASG in Queue Processor mode. This allows for a much longer grace period (up to 48 hours) for termination than the maximum heartbeat timeout of two hours. The feature is useful when pods require a long time to drain or when you need a shorter heartbeat timeout with a longer grace period. + +##### How it works + +- When NTH receives an ASG lifecycle termination event, it starts sending heartbeats to ASG to renew the heartbeat timeout associated with the ASG's termination lifecycle hook. +- The heartbeat timeout acts as a timer that starts when the termination event begins. +- Until the timeout reaches zero, the termination process is halted at the `Terminating:Wait` stage. +- By issuing heartbeats, the graceful termination duration can be extended up to 48 hours, limited by the global timeout. 
+ +##### How to use + +- Configure a termination lifecycle hook on ASG (required). Set the heartbeat timeout value to be longer than the `Heartbeat Interval`. Each heartbeat signal resets this timeout, extending the duration that an instance remains in the `Terminating:Wait` state. Without this lifecycle hook, the instance will terminate immediately when a termination event occurs. +- Configure `Heartbeat Interval` (required) and `Heartbeat Until` (optional). NTH operates normally without heartbeats if neither value is set. If only the interval is specified, `Heartbeat Until` defaults to 172800 seconds (48 hours) and heartbeats will be sent. `Heartbeat Until` must be provided with a valid `Heartbeat Interval`, otherwise NTH will fail to start. Any invalid values (wrong type or out of range) will also prevent NTH from starting. + +##### Configurations +###### `Heartbeat Interval` (Required) +- Time period between consecutive heartbeat signals (in seconds) +- Specifying this value enables heartbeats +- Range: 30 to 3600 seconds (30 seconds to 1 hour) +- Helm value (`values.yaml` or `--set`): `heartbeatInterval` +- CLI flag: `heartbeat-interval` +- Default value: none (heartbeats are disabled unless this is set) + +###### `Heartbeat Until` (Optional) +- Duration over which heartbeat signals are sent (in seconds) +- Must be provided with a valid `Heartbeat Interval` +- Range: 60 to 172800 seconds (1 minute to 48 hours) +- Helm value (`values.yaml` or `--set`): `heartbeatUntil` +- CLI flag: `heartbeat-until` +- Default value: 172800 (48 hours) + +###### Example Case + +- `Heartbeat Interval`: 1000 seconds +- `Heartbeat Until`: 4500 seconds +- `Heartbeat Timeout`: 3000 seconds + +| Time (s) | Event | Heartbeat Timeout (HT) | Heartbeat Until (HU) | Action | +|----------|-------------|------------------|----------------------|--------| +| 0 | Start | 3000 | 4500 | Termination Event Received | +| 1000 | HB1 Issued | 2000 -> 3000 | 3500 | Send Heartbeat | +| 2000 | HB2 Issued | 2000 -> 3000 | 
2500 | Send Heartbeat | +| 3000 | HB3 Issued | 2000 -> 3000 | 1500 | Send Heartbeat | +| 4000 | HB4 Issued | 2000 -> 3000 | 500 | Send Heartbeat | +| 4500 | HB Expires | 2500 | 0 | Stop Heartbeats | +| 7000 | Termination | - | - | Instance Terminates | + +Note: The instance can terminate earlier if its pods finish draining and are ready for termination. + +##### Example Helm Command + +```sh +helm upgrade --install aws-node-termination-handler \ + --namespace kube-system \ + --set enableSqsTerminationDraining=true \ + --set heartbeatInterval=1000 \ + --set heartbeatUntil=4500 \ + # other inputs... +``` + +##### Important Notes + +- Be aware of the global timeout. Instances cannot remain in a wait state indefinitely. The global timeout is 48 hours or 100 times the heartbeat timeout, whichever is smaller. This is the maximum amount of time that you can keep an instance in the `Terminating:Wait` state. +- Lifecycle heartbeats are only supported in Queue Processor mode. Setting `enableSqsTerminationDraining=false` and specifying heartbeat flags is prevented in Helm. Directly editing deployment settings to bypass this will cause NTH to fail. +- The heartbeat interval should be sufficiently shorter than the heartbeat timeout. There's a time gap between instance startup and NTH initialization. Setting the interval just slightly smaller than or equal to the timeout causes the heartbeat timeout to expire before the first heartbeat is issued. Provide adequate buffer time for NTH to complete initialization. +- Issuing heartbeats is part of the termination process. The number of instances whose termination NTH can handle concurrently is limited by the number of workers. This implies that heartbeats can be issued simultaneously for at most the number of instances specified by the `workers` flag. + ### Which one should I use? 
| Feature | IMDS Processor | Queue Processor | | :-------------------------------------------: | :------------: | :-------------: | @@ -91,6 +160,7 @@ When using the EC2 Console or EC2 API to terminate the instance, a state-change | ASG Termination Lifecycle State Change | ✅ | ❌ | | AZ Rebalance Recommendation | ❌ | ✅ | | Instance State Change Events | ❌ | ✅ | +| Issue Lifecycle Heartbeats | ❌ | ✅ | ### Kubernetes Compatibility @@ -626,5 +696,4 @@ In IMDS mode, metrics can be collected as follows: Contributions are welcome! Please read our [guidelines](https://github.com/aws/aws-node-termination-handler/blob/main/CONTRIBUTING.md) and our [Code of Conduct](https://github.com/aws/aws-node-termination-handler/blob/main/CODE_OF_CONDUCT.md) ## License -This project is licensed under the Apache-2.0 License. - +This project is licensed under the Apache-2.0 License. \ No newline at end of file diff --git a/config/helm/aws-node-termination-handler/templates/deployment.yaml b/config/helm/aws-node-termination-handler/templates/deployment.yaml index 2d7a8896..7c043fec 100644 --- a/config/helm/aws-node-termination-handler/templates/deployment.yaml +++ b/config/helm/aws-node-termination-handler/templates/deployment.yaml @@ -168,6 +168,10 @@ spec: value: {{ .Values.deleteSqsMsgIfNodeNotFound | quote }} - name: WORKERS value: {{ .Values.workers | quote }} + - name: HEARTBEAT_INTERVAL + value: {{ .Values.heartbeatInterval | quote }} + - name: HEARTBEAT_UNTIL + value: {{ .Values.heartbeatUntil | quote }} {{- with .Values.extraEnv }} {{- toYaml . 
| nindent 12 }} {{- end }} diff --git a/pkg/config/config.go b/pkg/config/config.go index 05fabdca..6e926bf4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -112,6 +112,9 @@ const ( queueURLConfigKey = "QUEUE_URL" completeLifecycleActionDelaySecondsKey = "COMPLETE_LIFECYCLE_ACTION_DELAY_SECONDS" deleteSqsMsgIfNodeNotFoundKey = "DELETE_SQS_MSG_IF_NODE_NOT_FOUND" + // heartbeat + heartbeatIntervalKey = "HEARTBEAT_INTERVAL" + heartbeatUntilKey = "HEARTBEAT_UNTIL" ) // Config arguments set via CLI, environment variables, or defaults @@ -166,6 +169,8 @@ type Config struct { CompleteLifecycleActionDelaySeconds int DeleteSqsMsgIfNodeNotFound bool UseAPIServerCacheToListPods bool + HeartbeatInterval int + HeartbeatUntil int } // ParseCliArgs parses cli arguments and uses environment variables as fallback values @@ -230,6 +235,8 @@ func ParseCliArgs() (config Config, err error) { flag.IntVar(&config.CompleteLifecycleActionDelaySeconds, "complete-lifecycle-action-delay-seconds", getIntEnv(completeLifecycleActionDelaySecondsKey, -1), "Delay completing the Autoscaling lifecycle action after a node has been drained.") flag.BoolVar(&config.DeleteSqsMsgIfNodeNotFound, "delete-sqs-msg-if-node-not-found", getBoolEnv(deleteSqsMsgIfNodeNotFoundKey, false), "If true, delete SQS Messages from the SQS Queue if the targeted node(s) are not found.") flag.BoolVar(&config.UseAPIServerCacheToListPods, "use-apiserver-cache", getBoolEnv(useAPIServerCache, false), "If true, leverage the k8s apiserver's index on pod's spec.nodeName to list pods on a node, instead of doing an etcd quorum read.") + flag.IntVar(&config.HeartbeatInterval, "heartbeat-interval", getIntEnv(heartbeatIntervalKey, -1), "The time period in seconds between consecutive heartbeat signals. Valid range: 30-3600 seconds (30 seconds to 1 hour).") + flag.IntVar(&config.HeartbeatUntil, "heartbeat-until", getIntEnv(heartbeatUntilKey, -1), "The duration in seconds over which heartbeat signals are sent. 
Valid range: 60-172800 seconds (1 minute to 48 hours).") flag.Parse() if isConfigProvided("pod-termination-grace-period", podTerminationGracePeriodConfigKey) && isConfigProvided("grace-period", gracePeriodConfigKey) { @@ -274,6 +281,27 @@ func ParseCliArgs() (config Config, err error) { panic("You must provide a node-name to the CLI or NODE_NAME environment variable.") } + // heartbeat value boundary and compatibility check + if !config.EnableSQSTerminationDraining && (config.HeartbeatInterval != -1 || config.HeartbeatUntil != -1) { + return config, fmt.Errorf("currently using IMDS mode. Heartbeat is only supported for Queue Processor mode") + } + if config.HeartbeatInterval != -1 && (config.HeartbeatInterval < 30 || config.HeartbeatInterval > 3600) { + return config, fmt.Errorf("invalid heartbeat-interval passed: %d Should be between 30 and 3600 seconds", config.HeartbeatInterval) + } + if config.HeartbeatUntil != -1 && (config.HeartbeatUntil < 60 || config.HeartbeatUntil > 172800) { + return config, fmt.Errorf("invalid heartbeat-until passed: %d Should be between 60 and 172800 seconds", config.HeartbeatUntil) + } + if config.HeartbeatInterval == -1 && config.HeartbeatUntil != -1 { + return config, fmt.Errorf("invalid heartbeat configuration: heartbeat-interval is required when heartbeat-until is set") + } + if config.HeartbeatInterval != -1 && config.HeartbeatUntil == -1 { + config.HeartbeatUntil = 172800 + log.Info().Msgf("Since heartbeat-until is not set, defaulting to %d seconds", config.HeartbeatUntil) + } + if config.HeartbeatInterval != -1 && config.HeartbeatUntil != -1 && config.HeartbeatInterval > config.HeartbeatUntil { + return config, fmt.Errorf("invalid heartbeat configuration: heartbeat-interval should be less than or equal to heartbeat-until") + } + // client-go expects these to be set in env vars os.Setenv(kubernetesServiceHostConfigKey, config.KubernetesServiceHost) os.Setenv(kubernetesServicePortConfigKey, config.KubernetesServicePort) @@ -332,6 
+360,8 @@ func (c Config) PrintJsonConfigArgs() { Str("ManagedTag", c.ManagedTag). Bool("use_provider_id", c.UseProviderId). Bool("use_apiserver_cache", c.UseAPIServerCacheToListPods). + Int("heartbeat_interval", c.HeartbeatInterval). + Int("heartbeat_until", c.HeartbeatUntil). Msg("aws-node-termination-handler arguments") } @@ -383,7 +413,9 @@ func (c Config) PrintHumanConfigArgs() { "\tmanaged-tag: %s,\n"+ "\tuse-provider-id: %t,\n"+ "\taws-endpoint: %s,\n"+ - "\tuse-apiserver-cache: %t,\n", + "\tuse-apiserver-cache: %t,\n"+ + "\theartbeat-interval: %d,\n"+ + "\theartbeat-until: %d\n", c.DryRun, c.NodeName, c.PodName, @@ -424,6 +456,8 @@ func (c Config) PrintHumanConfigArgs() { c.UseProviderId, c.AWSEndpoint, c.UseAPIServerCacheToListPods, + c.HeartbeatInterval, + c.HeartbeatUntil, ) } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index c411b9fd..8b7e2399 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -37,7 +37,7 @@ func TestParseCliArgsEnvSuccess(t *testing.T) { t.Setenv("ENABLE_SCHEDULED_EVENT_DRAINING", "true") t.Setenv("ENABLE_SPOT_INTERRUPTION_DRAINING", "false") t.Setenv("ENABLE_ASG_LIFECYCLE_DRAINING", "false") - t.Setenv("ENABLE_SQS_TERMINATION_DRAINING", "false") + t.Setenv("ENABLE_SQS_TERMINATION_DRAINING", "true") t.Setenv("ENABLE_REBALANCE_MONITORING", "true") t.Setenv("ENABLE_REBALANCE_DRAINING", "true") t.Setenv("GRACE_PERIOD", "12345") @@ -54,6 +54,8 @@ func TestParseCliArgsEnvSuccess(t *testing.T) { t.Setenv("METADATA_TRIES", "100") t.Setenv("CORDON_ONLY", "false") t.Setenv("USE_APISERVER_CACHE", "true") + t.Setenv("HEARTBEAT_INTERVAL", "30") + t.Setenv("HEARTBEAT_UNTIL", "60") nthConfig, err := config.ParseCliArgs() h.Ok(t, err) @@ -64,7 +66,7 @@ func TestParseCliArgsEnvSuccess(t *testing.T) { h.Equals(t, true, nthConfig.EnableScheduledEventDraining) h.Equals(t, false, nthConfig.EnableSpotInterruptionDraining) h.Equals(t, false, nthConfig.EnableASGLifecycleDraining) - h.Equals(t, false, 
nthConfig.EnableSQSTerminationDraining) + h.Equals(t, true, nthConfig.EnableSQSTerminationDraining) h.Equals(t, true, nthConfig.EnableRebalanceMonitoring) h.Equals(t, true, nthConfig.EnableRebalanceDraining) h.Equals(t, false, nthConfig.IgnoreDaemonSets) @@ -80,6 +82,8 @@ func TestParseCliArgsEnvSuccess(t *testing.T) { h.Equals(t, 100, nthConfig.MetadataTries) h.Equals(t, false, nthConfig.CordonOnly) h.Equals(t, true, nthConfig.UseAPIServerCacheToListPods) + h.Equals(t, 30, nthConfig.HeartbeatInterval) + h.Equals(t, 60, nthConfig.HeartbeatUntil) // Check that env vars were set value, ok := os.LookupEnv("KUBERNETES_SERVICE_HOST") @@ -101,7 +105,7 @@ func TestParseCliArgsSuccess(t *testing.T) { "--enable-scheduled-event-draining=true", "--enable-spot-interruption-draining=false", "--enable-asg-lifecycle-draining=false", - "--enable-sqs-termination-draining=false", + "--enable-sqs-termination-draining=true", "--enable-rebalance-monitoring=true", "--enable-rebalance-draining=true", "--ignore-daemon-sets=false", @@ -117,6 +121,8 @@ func TestParseCliArgsSuccess(t *testing.T) { "--metadata-tries=100", "--cordon-only=false", "--use-apiserver-cache=true", + "--heartbeat-interval=30", + "--heartbeat-until=60", } nthConfig, err := config.ParseCliArgs() h.Ok(t, err) @@ -128,7 +134,7 @@ func TestParseCliArgsSuccess(t *testing.T) { h.Equals(t, true, nthConfig.EnableScheduledEventDraining) h.Equals(t, false, nthConfig.EnableSpotInterruptionDraining) h.Equals(t, false, nthConfig.EnableASGLifecycleDraining) - h.Equals(t, false, nthConfig.EnableSQSTerminationDraining) + h.Equals(t, true, nthConfig.EnableSQSTerminationDraining) h.Equals(t, true, nthConfig.EnableRebalanceMonitoring) h.Equals(t, true, nthConfig.EnableRebalanceDraining) h.Equals(t, false, nthConfig.IgnoreDaemonSets) @@ -145,6 +151,8 @@ func TestParseCliArgsSuccess(t *testing.T) { h.Equals(t, false, nthConfig.CordonOnly) h.Equals(t, false, nthConfig.EnablePrometheus) h.Equals(t, true, 
nthConfig.UseAPIServerCacheToListPods) + h.Equals(t, 30, nthConfig.HeartbeatInterval) + h.Equals(t, 60, nthConfig.HeartbeatUntil) // Check that env vars were set value, ok := os.LookupEnv("KUBERNETES_SERVICE_HOST") @@ -176,6 +184,9 @@ func TestParseCliArgsOverrides(t *testing.T) { t.Setenv("WEBHOOK_TEMPLATE", "no") t.Setenv("METADATA_TRIES", "100") t.Setenv("CORDON_ONLY", "true") + t.Setenv("HEARTBEAT_INTERVAL", "3601") + t.Setenv("HEARTBEAT_UNTIL", "172801") + os.Args = []string{ "cmd", "--use-provider-id=false", @@ -201,6 +212,8 @@ func TestParseCliArgsOverrides(t *testing.T) { "--cordon-only=false", "--enable-prometheus-server=true", "--prometheus-server-port=2112", + "--heartbeat-interval=3600", + "--heartbeat-until=172800", } nthConfig, err := config.ParseCliArgs() h.Ok(t, err) @@ -229,6 +242,8 @@ func TestParseCliArgsOverrides(t *testing.T) { h.Equals(t, false, nthConfig.CordonOnly) h.Equals(t, true, nthConfig.EnablePrometheus) h.Equals(t, 2112, nthConfig.PrometheusPort) + h.Equals(t, 3600, nthConfig.HeartbeatInterval) + h.Equals(t, 172800, nthConfig.HeartbeatUntil) // Check that env vars were set value, ok := os.LookupEnv("KUBERNETES_SERVICE_HOST") diff --git a/pkg/monitor/sqsevent/asg-lifecycle-event.go b/pkg/monitor/sqsevent/asg-lifecycle-event.go index fc034931..a442b824 100644 --- a/pkg/monitor/sqsevent/asg-lifecycle-event.go +++ b/pkg/monitor/sqsevent/asg-lifecycle-event.go @@ -15,11 +15,14 @@ package sqsevent import ( "encoding/json" + "errors" "fmt" + "time" "github.com/aws/aws-node-termination-handler/pkg/monitor" "github.com/aws/aws-node-termination-handler/pkg/node" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/sqs" "github.com/rs/zerolog/log" @@ -95,19 +98,31 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m Description: fmt.Sprintf("ASG Lifecycle Termination event received. 
Instance will be interrupted at %s \n", event.getTime()), } + stopHeartbeatCh := make(chan struct{}) + interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, _ node.Node) error { + _, err = m.continueLifecycleAction(lifecycleDetail) if err != nil { return fmt.Errorf("continuing ASG termination lifecycle: %w", err) } log.Info().Str("lifecycleHookName", lifecycleDetail.LifecycleHookName).Str("instanceID", lifecycleDetail.EC2InstanceID).Msg("Completed ASG Lifecycle Hook") + + close(stopHeartbeatCh) return m.deleteMessage(message) } interruptionEvent.PreDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error { + nthConfig := n.GetNthConfig() + // If only HeartbeatInterval is set, HeartbeatUntil will default to 172800. + if nthConfig.HeartbeatInterval != -1 && nthConfig.HeartbeatUntil != -1 { + go m.checkHeartbeatTimeout(nthConfig.HeartbeatInterval, lifecycleDetail) + go m.SendHeartbeats(nthConfig.HeartbeatInterval, nthConfig.HeartbeatUntil, lifecycleDetail, stopHeartbeatCh) + } + err := n.TaintASGLifecycleTermination(interruptionEvent.NodeName, interruptionEvent.EventID) if err != nil { - log.Err(err).Msgf("Unable to taint node with taint %s:%s", node.ASGLifecycleTerminationTaint, interruptionEvent.EventID) + log.Err(err).Msgf("unable to taint node with taint %s:%s", node.ASGLifecycleTerminationTaint, interruptionEvent.EventID) } return nil } @@ -115,6 +130,108 @@ func (m SQSMonitor) asgTerminationToInterruptionEvent(event *EventBridgeEvent, m return &interruptionEvent, nil } +// Compare the heartbeatInterval with the heartbeat timeout and warn if (heartbeatInterval >= heartbeat timeout) +func (m SQSMonitor) checkHeartbeatTimeout(heartbeatInterval int, lifecycleDetail *LifecycleDetail) { + input := &autoscaling.DescribeLifecycleHooksInput{ + AutoScalingGroupName: aws.String(lifecycleDetail.AutoScalingGroupName), + LifecycleHookNames: []*string{aws.String(lifecycleDetail.LifecycleHookName)}, + } + + lifecyclehook, 
err := m.ASG.DescribeLifecycleHooks(input) + if err != nil { + log.Err(err).Msg("failed to describe lifecycle hook") + return + } + + if len(lifecyclehook.LifecycleHooks) == 0 { + log.Warn(). + Str("asgName", lifecycleDetail.AutoScalingGroupName). + Str("lifecycleHookName", lifecycleDetail.LifecycleHookName). + Msg("Tried to check heartbeat timeout, but no lifecycle hook found from ASG") + return + } + + heartbeatTimeout := int(*lifecyclehook.LifecycleHooks[0].HeartbeatTimeout) + + if heartbeatInterval >= heartbeatTimeout { + log.Warn().Msgf( + "Heartbeat interval (%d seconds) is equal to or greater than "+ + "the heartbeat timeout (%d seconds) for the lifecycle hook %s attached to ASG %s. "+ + "The node would likely be terminated before the heartbeat is sent", + heartbeatInterval, + heartbeatTimeout, + *lifecyclehook.LifecycleHooks[0].LifecycleHookName, + *lifecyclehook.LifecycleHooks[0].AutoScalingGroupName, + ) + } +} + +// Issue lifecycle heartbeats to reset the heartbeat timeout timer in ASG +func (m SQSMonitor) SendHeartbeats(heartbeatInterval int, heartbeatUntil int, lifecycleDetail *LifecycleDetail, stopCh <-chan struct{}) { + ticker := time.NewTicker(time.Duration(heartbeatInterval) * time.Second) + defer ticker.Stop() + timeout := time.After(time.Duration(heartbeatUntil) * time.Second) + + for { + select { + case <-stopCh: + log.Info().Str("asgName", lifecycleDetail.AutoScalingGroupName). + Str("lifecycleHookName", lifecycleDetail.LifecycleHookName). + Str("lifecycleActionToken", lifecycleDetail.LifecycleActionToken). + Str("instanceID", lifecycleDetail.EC2InstanceID). + Msg("Successfully cordoned and drained the node, stopping heartbeat") + return + case <-ticker.C: + err := m.recordLifecycleActionHeartbeat(lifecycleDetail) + if err != nil { + log.Err(err).Msg("invalid heartbeat target, stopping heartbeat") + return + } + case <-timeout: + log.Info().Str("asgName", lifecycleDetail.AutoScalingGroupName). 
+ Str("lifecycleHookName", lifecycleDetail.LifecycleHookName). + Str("lifecycleActionToken", lifecycleDetail.LifecycleActionToken). + Str("instanceID", lifecycleDetail.EC2InstanceID). + Msg("Heartbeat deadline exceeded, stopping heartbeat") + return + } + } +} + +func (m SQSMonitor) recordLifecycleActionHeartbeat(lifecycleDetail *LifecycleDetail) error { + input := &autoscaling.RecordLifecycleActionHeartbeatInput{ + AutoScalingGroupName: aws.String(lifecycleDetail.AutoScalingGroupName), + LifecycleHookName: aws.String(lifecycleDetail.LifecycleHookName), + LifecycleActionToken: aws.String(lifecycleDetail.LifecycleActionToken), + InstanceId: aws.String(lifecycleDetail.EC2InstanceID), + } + + // Stop the heartbeat if the target is invalid + _, err := m.ASG.RecordLifecycleActionHeartbeat(input) + if err != nil { + var awsErr awserr.Error + log.Warn(). + Str("asgName", lifecycleDetail.AutoScalingGroupName). + Str("lifecycleHookName", lifecycleDetail.LifecycleHookName). + Str("lifecycleActionToken", lifecycleDetail.LifecycleActionToken). + Str("instanceID", lifecycleDetail.EC2InstanceID). + Err(err). + Msg("Failed to send lifecycle heartbeat") + if errors.As(err, &awsErr) && awsErr.Code() == "ValidationError" { + return err + } + return nil + } + + log.Info(). + Str("asgName", lifecycleDetail.AutoScalingGroupName). + Str("lifecycleHookName", lifecycleDetail.LifecycleHookName). + Str("lifecycleActionToken", lifecycleDetail.LifecycleActionToken). + Str("instanceID", lifecycleDetail.EC2InstanceID). 
+ Msg("Successfully sent lifecycle heartbeat") + return nil +} + func (m SQSMonitor) deleteMessage(message *sqs.Message) error { errs := m.deleteMessages([]*sqs.Message{message}) if errs != nil { @@ -123,7 +240,7 @@ func (m SQSMonitor) deleteMessage(message *sqs.Message) error { return nil } -// Continues the lifecycle hook thereby indicating a successful action occured +// Continues the lifecycle hook thereby indicating a successful action occurred func (m SQSMonitor) continueLifecycleAction(lifecycleDetail *LifecycleDetail) (*autoscaling.CompleteLifecycleActionOutput, error) { return m.completeLifecycleAction(&autoscaling.CompleteLifecycleActionInput{ AutoScalingGroupName: &lifecycleDetail.AutoScalingGroupName, diff --git a/pkg/monitor/sqsevent/sqs-monitor_test.go b/pkg/monitor/sqsevent/sqs-monitor_test.go index 2b93085e..1884dddc 100644 --- a/pkg/monitor/sqsevent/sqs-monitor_test.go +++ b/pkg/monitor/sqsevent/sqs-monitor_test.go @@ -18,7 +18,9 @@ import ( "fmt" "strings" "testing" + "time" + "github.com/aws/aws-node-termination-handler/pkg/config" "github.com/aws/aws-node-termination-handler/pkg/monitor" "github.com/aws/aws-node-termination-handler/pkg/monitor/sqsevent" "github.com/aws/aws-node-termination-handler/pkg/node" @@ -276,7 +278,6 @@ func TestMonitor_AsgDirectToSqsSuccess(t *testing.T) { default: h.Ok(t, fmt.Errorf("Expected an event to be generated")) } - } func TestMonitor_AsgDirectToSqsTestNotification(t *testing.T) { @@ -520,7 +521,6 @@ func TestMonitor_DrainTasksASGFailure(t *testing.T) { default: h.Ok(t, fmt.Errorf("Expected to get an event with a failing post drain task")) } - } func TestMonitor_Failure(t *testing.T) { @@ -908,7 +908,93 @@ func TestMonitor_InstanceNotManaged(t *testing.T) { } } -// AWS Mock Helpers specific to sqs-monitor tests +func TestSendHeartbeats_EarlyClosure(t *testing.T) { + err := heartbeatTestHelper(nil, 3500, 1, 5) + h.Ok(t, err) + h.Assert(t, h.HeartbeatCallCount == 3, "3 Heartbeat Expected, got %d", 
h.HeartbeatCallCount) +} + +func TestSendHeartbeats_HeartbeatUntilExpire(t *testing.T) { + err := heartbeatTestHelper(nil, 8000, 1, 5) + h.Ok(t, err) + h.Assert(t, h.HeartbeatCallCount == 5, "5 Heartbeat Expected, got %d", h.HeartbeatCallCount) +} + +func TestSendHeartbeats_ErrThrottlingASG(t *testing.T) { + RecordLifecycleActionHeartbeatErr := awserr.New("Throttling", "Rate exceeded", nil) + err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 8000, 1, 6) + h.Ok(t, err) + h.Assert(t, h.HeartbeatCallCount == 6, "6 Heartbeat Expected, got %d", h.HeartbeatCallCount) +} + +func TestSendHeartbeats_ErrInvalidTarget(t *testing.T) { + RecordLifecycleActionHeartbeatErr := awserr.New("ValidationError", "No active Lifecycle Action found", nil) + err := heartbeatTestHelper(RecordLifecycleActionHeartbeatErr, 6000, 1, 4) + h.Ok(t, err) + h.Assert(t, h.HeartbeatCallCount == 1, "1 Heartbeat Expected, got %d", h.HeartbeatCallCount) +} + +func heartbeatTestHelper(RecordLifecycleActionHeartbeatErr error, sleepMilliSeconds int, heartbeatInterval int, heartbeatUntil int) error { + h.HeartbeatCallCount = 0 + + msg, err := getSQSMessageFromEvent(asgLifecycleEvent) + if err != nil { + return err + } + + sqsMock := h.MockedSQS{ + ReceiveMessageResp: sqs.ReceiveMessageOutput{Messages: []*sqs.Message{&msg}}, + } + dnsNodeName := "ip-10-0-0-157.us-east-2.compute.internal" + ec2Mock := h.MockedEC2{ + DescribeInstancesResp: getDescribeInstancesResp(dnsNodeName, true, true), + } + asgMock := h.MockedASG{ + CompleteLifecycleActionResp: autoscaling.CompleteLifecycleActionOutput{}, + RecordLifecycleActionHeartbeatResp: autoscaling.RecordLifecycleActionHeartbeatOutput{}, + RecordLifecycleActionHeartbeatErr: RecordLifecycleActionHeartbeatErr, + HeartbeatTimeout: 30, + } + + drainChan := make(chan monitor.InterruptionEvent, 1) + sqsMonitor := sqsevent.SQSMonitor{ + SQS: sqsMock, + EC2: ec2Mock, + ASG: asgMock, + InterruptionChan: drainChan, + BeforeCompleteLifecycleAction: func() { + 
time.Sleep(time.Duration(sleepMilliSeconds) * time.Millisecond) + }, + } + + if err := sqsMonitor.Monitor(); err != nil { + return err + } + + nthConfig := &config.Config{ + HeartbeatInterval: heartbeatInterval, + HeartbeatUntil: heartbeatUntil, + } + + testNode, _ := node.New(*nthConfig, nil) + result := <-drainChan + + if result.PreDrainTask == nil { + return fmt.Errorf("PreDrainTask should have been set") + } + if err := result.PreDrainTask(result, *testNode); err != nil { + return err + } + + if result.PostDrainTask == nil { + return fmt.Errorf("PostDrainTask should have been set") + } + if err := result.PostDrainTask(result, *testNode); err != nil { + return err + } + + return nil +} func getDescribeInstancesResp(privateDNSName string, withASGTag bool, withManagedTag bool) ec2.DescribeInstancesOutput { tags := []*ec2.Tag{} diff --git a/pkg/node/node.go b/pkg/node/node.go index 7e323d13..204c5de6 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -280,6 +280,10 @@ func (n Node) MarkForUncordonAfterReboot(nodeName string) error { return nil } +func (n Node) GetNthConfig() config.Config { + return n.nthConfig +} + // addLabel will add a label to the node given a label key and value // Specifying true for the skipExisting parameter will skip adding the label if it already exists func (n Node) addLabel(nodeName string, key string, value string, skipExisting bool) error { diff --git a/pkg/test/aws-mocks.go b/pkg/test/aws-mocks.go index 8d5c8ae5..79626687 100644 --- a/pkg/test/aws-mocks.go +++ b/pkg/test/aws-mocks.go @@ -14,6 +14,7 @@ package test import ( + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" "github.com/aws/aws-sdk-go/service/ec2" @@ -56,12 +57,17 @@ func (m MockedEC2) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.De // MockedASG mocks the autoscaling API type MockedASG struct { autoscalingiface.AutoScalingAPI - CompleteLifecycleActionResp 
autoscaling.CompleteLifecycleActionOutput - CompleteLifecycleActionErr error - DescribeAutoScalingInstancesResp autoscaling.DescribeAutoScalingInstancesOutput - DescribeAutoScalingInstancesErr error - DescribeTagsPagesResp autoscaling.DescribeTagsOutput - DescribeTagsPagesErr error + CompleteLifecycleActionResp autoscaling.CompleteLifecycleActionOutput + CompleteLifecycleActionErr error + DescribeAutoScalingInstancesResp autoscaling.DescribeAutoScalingInstancesOutput + DescribeAutoScalingInstancesErr error + DescribeTagsPagesResp autoscaling.DescribeTagsOutput + DescribeTagsPagesErr error + RecordLifecycleActionHeartbeatResp autoscaling.RecordLifecycleActionHeartbeatOutput + RecordLifecycleActionHeartbeatErr error + HeartbeatTimeout int + AutoScalingGroupName string + LifecycleHookName string } // CompleteLifecycleAction mocks the autoscaling.CompleteLifecycleAction API call @@ -81,3 +87,26 @@ func (m MockedASG) DescribeTagsPages(input *autoscaling.DescribeTagsInput, fn de fn(&m.DescribeTagsPagesResp, true) return m.DescribeTagsPagesErr } + +var HeartbeatCallCount int + +// RecordLifecycleActionHeartbeat mocks the autoscaling.RecordLifecycleActionHeartbeat API call +func (m MockedASG) RecordLifecycleActionHeartbeat(input *autoscaling.RecordLifecycleActionHeartbeatInput) (*autoscaling.RecordLifecycleActionHeartbeatOutput, error) { + HeartbeatCallCount++ + if m.RecordLifecycleActionHeartbeatErr != nil && HeartbeatCallCount%2 == 1 { + return &m.RecordLifecycleActionHeartbeatResp, m.RecordLifecycleActionHeartbeatErr + } + return &m.RecordLifecycleActionHeartbeatResp, nil +} + +func (m MockedASG) DescribeLifecycleHooks(input *autoscaling.DescribeLifecycleHooksInput) (*autoscaling.DescribeLifecycleHooksOutput, error) { + return &autoscaling.DescribeLifecycleHooksOutput{ + LifecycleHooks: []*autoscaling.LifecycleHook{ + { + AutoScalingGroupName: &m.AutoScalingGroupName, + LifecycleHookName: &m.LifecycleHookName, + HeartbeatTimeout: aws.Int64(int64(m.HeartbeatTimeout)), + 
}, + }, + }, nil +} diff --git a/test/e2e/asg-lifecycle-sqs-heartbeat-test b/test/e2e/asg-lifecycle-sqs-heartbeat-test new file mode 100755 index 00000000..6b0afc2e --- /dev/null +++ b/test/e2e/asg-lifecycle-sqs-heartbeat-test @@ -0,0 +1,228 @@ +#!/bin/bash +set -euo pipefail + +# Available env vars: +# $TMP_DIR +# $CLUSTER_NAME +# $KUBECONFIG +# $NODE_TERMINATION_HANDLER_DOCKER_REPO +# $NODE_TERMINATION_HANDLER_DOCKER_TAG +# $WEBHOOK_DOCKER_REPO +# $WEBHOOK_DOCKER_TAG +# $AEMM_URL +# $AEMM_VERSION + + +function fail_and_exit { + echo "❌ ASG Lifecycle SQS Heartbeat Test failed $CLUSTER_NAME ❌" + exit "${1:-1}" +} + +echo "Starting ASG Lifecycle SQS Heartbeat Test for Node Termination Handler" +START_TIME=$(date -u +"%Y-%m-%dT%TZ") + +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" + +common_helm_args=() + +localstack_helm_args=( + upgrade + --install + --namespace default + "$CLUSTER_NAME-localstack" + "$SCRIPTPATH/../../config/helm/localstack/" + --set nodeSelector."${NTH_CONTROL_LABEL}" + --set defaultRegion="${AWS_REGION}" + --wait +) + +set -x +helm "${localstack_helm_args[@]}" +set +x + +sleep 10 + +RUN_INSTANCE_CMD="awslocal ec2 run-instances --private-ip-address ${WORKER_IP} --region ${AWS_REGION} --tag-specifications 'ResourceType=instance,Tags=[{Key=aws:autoscaling:groupName,Value=nth-integ-test},{Key=aws-node-termination-handler/managed,Value=blah}]'" +localstack_pod=$(kubectl get pods --selector app=localstack --field-selector="status.phase=Running" \ + -o go-template --template '{{range .items}}{{.metadata.name}} {{.metadata.creationTimestamp}}{{"\n"}}{{end}}' \ + | awk '$2 >= "'"${START_TIME//+0000/Z}"'" { print $1 }') +echo "🥑 Using localstack pod ${localstack_pod}" +run_instances_resp=$(kubectl exec -i "${localstack_pod}" -- bash -c "${RUN_INSTANCE_CMD}") +private_dns_name=$(echo "${run_instances_resp}" | jq -r '.Instances[] .PrivateDnsName') +instance_id=$(echo "${run_instances_resp}" | jq -r '.Instances[] .InstanceId') +echo "🥑 Started mock EC2 
instance ($instance_id) w/ private DNS name: ${private_dns_name}" +set -x +CREATE_SQS_CMD="awslocal sqs create-queue --queue-name "${CLUSTER_NAME}-queue" --attributes MessageRetentionPeriod=300 --region ${AWS_REGION}" +queue_url=$(kubectl exec -i "${localstack_pod}" -- bash -c "${CREATE_SQS_CMD}" | jq -r .QueueUrl) +set +x + +echo "🥑 Created SQS Queue ${queue_url}" + +# arguments specific to heartbeat testing +COMPLETE_LIFECYCLE_ACTION_DELAY_SECONDS=120 +HEARTBEAT_INTERVAL=30 +HEARTBEAT_UNTIL=100 + +anth_helm_args=( + upgrade + --install + --namespace kube-system + "$CLUSTER_NAME-acth" + "$SCRIPTPATH/../../config/helm/aws-node-termination-handler/" + --set completeLifecycleActionDelaySeconds="$COMPLETE_LIFECYCLE_ACTION_DELAY_SECONDS" + --set heartbeatInterval="$HEARTBEAT_INTERVAL" + --set heartbeatUntil="$HEARTBEAT_UNTIL" + --set image.repository="$NODE_TERMINATION_HANDLER_DOCKER_REPO" + --set image.tag="$NODE_TERMINATION_HANDLER_DOCKER_TAG" + --set nodeSelector."${NTH_CONTROL_LABEL}" + --set tolerations[0].operator=Exists + --set awsAccessKeyID=foo + --set awsSecretAccessKey=bar + --set awsRegion="${AWS_REGION}" + --set awsEndpoint="http://localstack.default" + --set checkTagBeforeDraining=false + --set enableSqsTerminationDraining=true + --set queueURL="${queue_url}" + --wait +) +[[ -n "${NODE_TERMINATION_HANDLER_DOCKER_PULL_POLICY-}" ]] && + anth_helm_args+=(--set image.pullPolicy="$NODE_TERMINATION_HANDLER_DOCKER_PULL_POLICY") +[[ ${#common_helm_args[@]} -gt 0 ]] && + anth_helm_args+=("${common_helm_args[@]}") + +set -x +helm "${anth_helm_args[@]}" +set +x + +emtp_helm_args=( + upgrade + --install + --namespace default + "$CLUSTER_NAME-emtp" + "$SCRIPTPATH/../../config/helm/webhook-test-proxy/" + --set webhookTestProxy.image.repository="$WEBHOOK_DOCKER_REPO" + --set webhookTestProxy.image.tag="$WEBHOOK_DOCKER_TAG" + --wait +) +[[ -n "${WEBHOOK_DOCKER_PULL_POLICY-}" ]] && + emtp_helm_args+=(--set webhookTestProxy.image.pullPolicy="$WEBHOOK_DOCKER_PULL_POLICY") 
+[[ ${#common_helm_args[@]} -gt 0 ]] && + emtp_helm_args+=("${common_helm_args[@]}") + +set -x +helm "${emtp_helm_args[@]}" +set +x + +TAINT_CHECK_CYCLES=15 +TAINT_CHECK_SLEEP=15 + +DEPLOYED=0 + +for i in $(seq 1 $TAINT_CHECK_CYCLES); do + if [[ $(kubectl get deployments regular-pod-test -o jsonpath='{.status.unavailableReplicas}') -eq 0 ]]; then + echo "✅ Verified regular-pod-test pod was scheduled and started!" + DEPLOYED=1 + break + fi + echo "Setup Loop $i/$TAINT_CHECK_CYCLES, sleeping for $TAINT_CHECK_SLEEP seconds" + sleep $TAINT_CHECK_SLEEP +done + +if [[ $DEPLOYED -eq 0 ]]; then + echo "❌ regular-pod-test pod deployment failed" + fail_and_exit 2 +fi + +ASG_TERMINATION_EVENT=$(cat < /dev/null; then + echo "✅ Verified the worker node was cordoned!" + cordoned=1 + fi + + if [[ $cordoned -eq 1 && $(kubectl get deployments regular-pod-test -o=jsonpath='{.status.unavailableReplicas}') -eq 1 ]]; then + echo "✅ Verified the regular-pod-test pod was evicted!" + echo "✅ ASG Lifecycle SQS Test Passed with Heartbeat $CLUSTER_NAME! ✅" + exit 0 + fi + echo "Assertion Loop $i/$TAINT_CHECK_CYCLES, sleeping for $TAINT_CHECK_SLEEP seconds" + sleep $TAINT_CHECK_SLEEP +done + +if [[ $cordoned -eq 0 ]]; then + echo "❌ Worker node was not cordoned" +else + echo "❌ regular-pod-test was not evicted" +fi + +fail_and_exit 1
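
Reviewer note on the README example case: the timeline in the example table (heartbeats at 1000 s through 4000 s, termination at 7000 s) can be checked with a small simulation. This is a minimal sketch, not NTH code. `heartbeatTimeline` and `maxGlobalTimeout` are hypothetical names introduced for illustration, and the sketch assumes that heartbeats land exactly on each interval tick, that a tick coinciding with `Heartbeat Until` still sends a heartbeat, and that draining never finishes early.

```go
package main

import "fmt"

// maxGlobalTimeout is the 48-hour ceiling described in the README notes.
const maxGlobalTimeout = 172800

// heartbeatTimeline is a hypothetical helper (not part of NTH). All arguments
// are in seconds. It returns the times at which heartbeats are sent and the
// time at which the lifecycle hook finally times out.
func heartbeatTimeline(interval, until, timeout int) ([]int, int) {
	if interval <= 0 {
		return nil, timeout // no heartbeats: the hook simply runs out
	}

	// Global timeout: 48 hours or 100x the heartbeat timeout, whichever is smaller.
	global := 100 * timeout
	if global > maxGlobalTimeout {
		global = maxGlobalTimeout
	}

	var beats []int
	expiry := timeout // the hook timer starts at t=0
	for t := interval; t <= until; t += interval {
		if t >= expiry {
			break // the hook timed out before this heartbeat could be sent
		}
		beats = append(beats, t)
		expiry = t + timeout // each heartbeat restarts the timer
		if expiry > global {
			expiry = global
		}
	}
	return beats, expiry
}

func main() {
	beats, end := heartbeatTimeline(1000, 4500, 3000)
	fmt.Println(beats, end) // prints "[1000 2000 3000 4000] 7000"
}
```

Calling the helper with an interval equal to the timeout, for example `heartbeatTimeline(3000, 4500, 3000)`, yields no heartbeats and an expiry at 3000 s, which is the situation the `checkHeartbeatTimeout` warning in this change is meant to flag.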