From 4a674523a1fc6dddbca5639b4eec912d7310ce32 Mon Sep 17 00:00:00 2001 From: lucaswzhang Date: Wed, 18 Aug 2021 10:37:34 +0800 Subject: [PATCH] add reconciler.v1 --- pkg/controller.v1/common/job.go | 129 +---- pkg/controller.v1/common/pod.go | 62 +-- pkg/controller.v1/common/pod_test.go | 13 +- pkg/controller.v1/common/service.go | 69 +-- pkg/controller.v1/common/service_test.go | 4 +- pkg/controller.v1/common/status.go | 16 +- pkg/controller.v1/common/util.go | 44 ++ pkg/core/job.go | 115 +++++ pkg/core/pod.go | 77 +++ pkg/core/service.go | 90 ++++ pkg/core/status.go | 27 + pkg/core/utils.go | 19 + pkg/reconciler.v1/common/README.md | 24 + pkg/reconciler.v1/common/gang.go | 34 ++ .../common/gang_scheduler_framework.go | 21 + pkg/reconciler.v1/common/gang_volcano.go | 193 +++++++ pkg/reconciler.v1/common/interface.go | 260 ++++++++++ pkg/reconciler.v1/common/job.go | 478 ++++++++++++++++++ pkg/reconciler.v1/common/pod.go | 276 ++++++++++ pkg/reconciler.v1/common/pod_test.go | 143 ++++++ pkg/reconciler.v1/common/reconciler.go | 147 ++++++ pkg/reconciler.v1/common/service.go | 221 ++++++++ pkg/reconciler.v1/common/service_test.go | 103 ++++ pkg/reconciler.v1/common/utils.go | 66 +++ pkg/reconciler.v1/common/utils_test.go | 65 +++ pkg/util/counter.go | 71 +++ .../reconciler.v1/test_job/dummy_client.go | 60 +++ .../test_job/test_job_reconciler.go | 131 +++++ 28 files changed, 2692 insertions(+), 266 deletions(-) create mode 100644 pkg/core/job.go create mode 100644 pkg/core/pod.go create mode 100644 pkg/core/service.go create mode 100644 pkg/core/status.go create mode 100644 pkg/core/utils.go create mode 100644 pkg/reconciler.v1/common/README.md create mode 100644 pkg/reconciler.v1/common/gang.go create mode 100644 pkg/reconciler.v1/common/gang_scheduler_framework.go create mode 100644 pkg/reconciler.v1/common/gang_volcano.go create mode 100644 pkg/reconciler.v1/common/interface.go create mode 100644 pkg/reconciler.v1/common/job.go create mode 100644 pkg/reconciler.v1/common/pod.go create mode 100644 pkg/reconciler.v1/common/pod_test.go create mode 100644 pkg/reconciler.v1/common/reconciler.go create mode 100644 pkg/reconciler.v1/common/service.go create mode 100644 pkg/reconciler.v1/common/service_test.go create mode 100644 pkg/reconciler.v1/common/utils.go create mode 100644 pkg/reconciler.v1/common/utils_test.go create mode 100644 pkg/util/counter.go create mode 100644 test_job/reconciler.v1/test_job/dummy_client.go create mode 100644 test_job/reconciler.v1/test_job/test_job_reconciler.go diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go index a47c77b3..9f6cf428 100644 --- a/pkg/controller.v1/common/job.go +++ b/pkg/controller.v1/common/job.go @@ -3,11 +3,11 @@ package common import ( "fmt" "reflect" - "sort" "time" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/expectation" + "github.com/kubeflow/common/pkg/core" commonutil "github.com/kubeflow/common/pkg/util" "github.com/kubeflow/common/pkg/util/k8sutil" @@ -49,49 +49,7 @@ func (jc *JobController) DeletePodsAndServices(runPolicy *apiv1.RunPolicy, job i // recordAbnormalPods records the active pod whose latest condition is not in True status. func (jc *JobController) recordAbnormalPods(activePods []*v1.Pod, object runtime.Object) { - for _, pod := range activePods { - // If the pod starts running, should checks the container statuses rather than the conditions. 
- recordContainerStatus := func(status *v1.ContainerStatus) { - if status.State.Terminated != nil && status.State.Terminated.ExitCode != 0 { - terminated := status.State.Terminated - jc.Recorder.Eventf(object, v1.EventTypeWarning, terminated.Reason, - "Error pod %s container %s exitCode: %d terminated message: %s", - pod.Name, status.Name, terminated.ExitCode, terminated.Message) - } - // The terminated state and waiting state don't simultaneously exists, checks them at the same time. - if status.State.Waiting != nil && status.State.Waiting.Message != "" { - wait := status.State.Waiting - jc.Recorder.Eventf(object, v1.EventTypeWarning, wait.Reason, - "Error pod %s container %s waiting message: %s", pod.Name, status.Name, wait.Message) - } - } - if len(pod.Status.ContainerStatuses) != 0 { - for _, status := range pod.Status.ContainerStatuses { - recordContainerStatus(&status) - } - // If the pod has container status info, that means the init container statuses are normal. - continue - } - if len(pod.Status.InitContainerStatuses) != 0 { - for _, status := range pod.Status.InitContainerStatuses { - recordContainerStatus(&status) - } - continue - } - if len(pod.Status.Conditions) == 0 { - continue - } - // Should not modify the original pod which is stored in the informer cache. - status := pod.Status.DeepCopy() - sort.Slice(status.Conditions, func(i, j int) bool { - return status.Conditions[i].LastTransitionTime.After(status.Conditions[j].LastTransitionTime.Time) - }) - condition := status.Conditions[0] - if condition.Status == v1.ConditionTrue { - continue - } - jc.Recorder.Eventf(object, v1.EventTypeWarning, condition.Reason, "Error pod %s condition message: %s", pod.Name, condition.Message) - } + core.RecordAbnormalPods(activePods, object, jc.Recorder) } // ReconcileJobs checks and updates replicas for each given ReplicaSpec. @@ -340,7 +298,7 @@ func (jc *JobController) ReconcileJobs( } // ResetExpectations reset the expectation for creates and deletes of pod/service to zero. -func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) { +func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) { for rtype := range replicas { expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) jc.Expectations.SetExpectations(expectationPodsKey, 0, 0) @@ -351,54 +309,14 @@ func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.Rep // PastActiveDeadline checks if job has ActiveDeadlineSeconds field set and if it is exceeded. 
func (jc *JobController) PastActiveDeadline(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus) bool { - if runPolicy.ActiveDeadlineSeconds == nil || jobStatus.StartTime == nil { - return false - } - now := metav1.Now() - start := jobStatus.StartTime.Time - duration := now.Time.Sub(start) - allowedDuration := time.Duration(*runPolicy.ActiveDeadlineSeconds) * time.Second - return duration >= allowedDuration + return core.PastActiveDeadline(runPolicy, jobStatus) } // PastBackoffLimit checks if container restartCounts sum exceeds BackoffLimit // this method applies only to pods with restartPolicy == OnFailure or Always func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pods []*v1.Pod) (bool, error) { - if runPolicy.BackoffLimit == nil { - return false, nil - } - result := int32(0) - for rtype, spec := range replicas { - if spec.RestartPolicy != apiv1.RestartPolicyOnFailure && spec.RestartPolicy != apiv1.RestartPolicyAlways { - log.Warnf("The restart policy of replica %v of the job %v is not OnFailure or Always. Not counted in backoff limit.", rtype, jobName) - continue - } - // Convert ReplicaType to lower string. - pods, err := jc.FilterPodsForReplicaType(pods, rtype) - if err != nil { - return false, err - } - for i := range pods { - po := pods[i] - if po.Status.Phase != v1.PodRunning { - continue - } - for j := range po.Status.InitContainerStatuses { - stat := po.Status.InitContainerStatuses[j] - result += stat.RestartCount - } - for j := range po.Status.ContainerStatuses { - stat := po.Status.ContainerStatuses[j] - result += stat.RestartCount - } - } - } - - if *runPolicy.BackoffLimit == 0 { - return result > 0, nil - } - return result >= *runPolicy.BackoffLimit, nil + return core.PastBackoffLimit(jobName, runPolicy, replicas, pods, jc.FilterPodsForReplicaType) } func (jc *JobController) CleanupJob(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus, job interface{}) error { @@ -435,40 +353,5 @@ func (jc *JobController) CleanupJob(runPolicy *apiv1.RunPolicy, jobStatus apiv1. 
} func (jc *JobController) calcPGMinResources(minMember int32, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) *v1.ResourceList { - var replicasPriority ReplicasPriority - for t, replica := range replicas { - rp := ReplicaPriority{0, *replica} - pc := replica.Template.Spec.PriorityClassName - - priorityClass, err := jc.PriorityClassLister.Get(pc) - if err != nil || priorityClass == nil { - log.Warnf("Ignore task %s priority class %s: %v", t, pc, err) - } else { - rp.priority = priorityClass.Value - } - - replicasPriority = append(replicasPriority, rp) - } - - sort.Sort(replicasPriority) - - minAvailableTasksRes := v1.ResourceList{} - podCnt := int32(0) - for _, task := range replicasPriority { - if task.Replicas == nil { - continue - } - - for i := int32(0); i < *task.Replicas; i++ { - if podCnt >= minMember { - break - } - podCnt++ - for _, c := range task.Template.Spec.Containers { - AddResourceList(minAvailableTasksRes, c.Resources.Requests, c.Resources.Limits) - } - } - } - - return &minAvailableTasksRes + return CalcPGMinResources(minMember, replicas, jc.PriorityClassLister.Get) } diff --git a/pkg/controller.v1/common/pod.go b/pkg/controller.v1/common/pod.go index ffbce85b..b29b469f 100644 --- a/pkg/controller.v1/common/pod.go +++ b/pkg/controller.v1/common/pod.go @@ -22,6 +22,7 @@ import ( apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" + "github.com/kubeflow/common/pkg/core" commonutil "github.com/kubeflow/common/pkg/util" utillabels "github.com/kubeflow/common/pkg/util/labels" trainutil "github.com/kubeflow/common/pkg/util/train" @@ -255,59 +256,13 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error) // FilterPodsForReplicaType returns pods belong to a replicaType. func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { - var result []*v1.Pod - - selector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabel: string(replicaType), - }) - - // TODO(#149): Remove deprecated selector. - deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabelDeprecated: string(replicaType), - }) - - for _, pod := range pods { - set := labels.Set(pod.Labels) - if !selector.Matches(set) && !deprecatedSelector.Matches(set) { - continue - } - result = append(result, pod) - } - return result, nil + return core.FilterPodsForReplicaType(pods, replicaType) } // getPodSlices returns a slice, which element is the slice of pod. // It gives enough information to caller to make decision to up/down scale resources. 
func (jc *JobController) GetPodSlices(pods []*v1.Pod, replicas int, logger *log.Entry) [][]*v1.Pod { - podSlices := make([][]*v1.Pod, calculatePodSliceSize(pods, replicas)) - for _, pod := range pods { - index, err := utillabels.ReplicaIndex(pod.Labels) - if err != nil { - logger.Warningf("Error obtaining replica index from Pod %s/%s: %v", pod.Namespace, pod.Name, err) - continue - } - if index < 0 || index >= replicas { - logger.Warningf("The label index is not expected: %d, pod: %s/%s", index, pod.Namespace, pod.Name) - } - - podSlices[index] = append(podSlices[index], pod) - } - return podSlices -} - -// calculatePodSliceSize compare max pod index with desired replicas and return larger size -func calculatePodSliceSize(pods []*v1.Pod, replicas int) int { - size := 0 - for _, pod := range pods { - index, err := utillabels.ReplicaIndex(pod.Labels) - if err != nil { - continue - } - size = MaxInt(size, index) - } - - // size comes from index, need to +1 to indicate real size - return MaxInt(size+1, replicas) + return core.GetPodSlices(pods, replicas, logger) } // ReconcilePods checks and updates pods for each given ReplicaSpec. @@ -462,7 +417,7 @@ func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, ind logger.Warning(errMsg) jc.Recorder.Event(runtimeObject, v1.EventTypeWarning, podTemplateRestartPolicyReason, errMsg) } - setRestartPolicy(podTemplate, spec) + core.SetRestartPolicy(podTemplate, spec) // if gang-scheduling is enabled: // 1. if user has specified other scheduler, we report a warning without overriding any fields. @@ -512,15 +467,6 @@ func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, ind return nil } -func setRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *apiv1.ReplicaSpec) { - // This is necessary since restartPolicyExitCode is not supported in v1.PodTemplateSpec - if spec.RestartPolicy == apiv1.RestartPolicyExitCode { - podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever - } else { - podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicy(spec.RestartPolicy) - } -} - func isNonGangSchedulerSet(replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) bool { for _, spec := range replicas { if spec.Template.Spec.SchedulerName != "" && spec.Template.Spec.SchedulerName != gangSchedulerName { diff --git a/pkg/controller.v1/common/pod_test.go b/pkg/controller.v1/common/pod_test.go index ccd5fae6..827515a5 100644 --- a/pkg/controller.v1/common/pod_test.go +++ b/pkg/controller.v1/common/pod_test.go @@ -3,13 +3,14 @@ package common import ( "testing" - v12 "github.com/kubeflow/common/test_job/test_util/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/core" testjobv1 "github.com/kubeflow/common/test_job/apis/test_job/v1" + v12 "github.com/kubeflow/common/test_job/test_util/v1" + "github.com/stretchr/testify/assert" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) func TestSetRestartPolicy(t *testing.T) { @@ -59,7 +60,7 @@ func TestSetRestartPolicy(t *testing.T) { for _, c := range testCase { spec := c.testJob.Spec.TestReplicaSpecs[c.expectedType] podTemplate := spec.Template - setRestartPolicy(&podTemplate, spec) + core.SetRestartPolicy(&podTemplate, spec) if podTemplate.Spec.RestartPolicy != c.expectedRestartPolicy { t.Errorf("Expected %s, got %s", c.expectedRestartPolicy, podTemplate.Spec.RestartPolicy) } @@ -142,7 +143,7 @@ func TestCalculatePodSliceSize(t *testing.T) { } for _, tc 
:= range testCases { - result := calculatePodSliceSize(tc.pods, tc.replicas) + result := core.CalculatePodSliceSize(tc.pods, tc.replicas) assert.Equal(t, tc.expectedSize, result) } } diff --git a/pkg/controller.v1/common/service.go b/pkg/controller.v1/common/service.go index f85e0f4b..7eb99c41 100644 --- a/pkg/controller.v1/common/service.go +++ b/pkg/controller.v1/common/service.go @@ -20,6 +20,7 @@ import ( apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" + "github.com/kubeflow/common/pkg/core" commonutil "github.com/kubeflow/common/pkg/util" utillabels "github.com/kubeflow/common/pkg/util/labels" @@ -139,60 +140,14 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service // FilterServicesForReplicaType returns service belong to a replicaType. func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { - var result []*v1.Service - - selector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabel: string(replicaType), - }) - - // TODO(#149): Remove deprecated selector. - deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabelDeprecated: string(replicaType), - }) - - for _, service := range services { - set := labels.Set(service.Labels) - if !selector.Matches(set) && !deprecatedSelector.Matches(set) { - continue - } - result = append(result, service) - } - return result, nil + return core.FilterServicesForReplicaType(services, replicaType) } // GetServiceSlices returns a slice, which element is the slice of service. // Assume the return object is serviceSlices, then serviceSlices[i] is an // array of pointers to services corresponding to Services for replica i. func (jc *JobController) GetServiceSlices(services []*v1.Service, replicas int, logger *log.Entry) [][]*v1.Service { - serviceSlices := make([][]*v1.Service, calculateServiceSliceSize(services, replicas)) - for _, service := range services { - index, err := utillabels.ReplicaIndex(service.Labels) - if err != nil { - logger.Warningf("Error obtaining index for service %s/%s: %v", service.Namespace, service.Name, err) - continue - } - if index < 0 || index >= replicas { - logger.Warningf("The label index is not expected: %d, service: %s/%s", index, service.Namespace, service.Name) - } - - serviceSlices[index] = append(serviceSlices[index], service) - } - return serviceSlices -} - -// calculateServiceSliceSize compare max pod index with desired replicas and return larger size -func calculateServiceSliceSize(services []*v1.Service, replicas int) int { - size := 0 - for _, svc := range services { - index, err := utillabels.ReplicaIndex(svc.Labels) - if err != nil { - continue - } - size = MaxInt(size, index) - } - - // size comes from index, need to +1 to indicate real size - return MaxInt(size+1, replicas) + return core.GetServiceSlices(services, replicas, logger) } // reconcileServices checks and updates services for each given ReplicaSpec. @@ -245,23 +200,7 @@ func (jc *JobController) ReconcileServices( // GetPortsFromJob gets the ports of job container. Port could be nil, if distributed communication strategy doesn't need and no other ports that need to be exposed. 
func (jc *JobController) GetPortsFromJob(spec *apiv1.ReplicaSpec) (map[string]int32, error) { - ports := make(map[string]int32) - - containers := spec.Template.Spec.Containers - for _, container := range containers { - if container.Name == jc.Controller.GetDefaultContainerName() { - containerPorts := container.Ports - if len(containerPorts) == 0 { - return nil, nil - } - for _, port := range containerPorts { - ports[port.Name] = port.ContainerPort - } - return ports, nil - } - } - - return nil, fmt.Errorf("failed to find the port") + return core.GetPortsFromJob(spec, jc.Controller.GetDefaultContainerName()) } // createNewService creates a new service for the given index and type. diff --git a/pkg/controller.v1/common/service_test.go b/pkg/controller.v1/common/service_test.go index 2592be99..9e607090 100644 --- a/pkg/controller.v1/common/service_test.go +++ b/pkg/controller.v1/common/service_test.go @@ -3,6 +3,8 @@ package common import ( "testing" + "github.com/kubeflow/common/pkg/core" + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" @@ -74,7 +76,7 @@ func TestCalculateServiceSliceSize(t *testing.T) { } for _, tc := range testCases { - result := calculateServiceSliceSize(tc.services, tc.replicas) + result := core.CalculateServiceSliceSize(tc.services, tc.replicas) assert.Equal(t, tc.expectedSize, result) } } diff --git a/pkg/controller.v1/common/status.go b/pkg/controller.v1/common/status.go index c2cd075e..d6d250be 100644 --- a/pkg/controller.v1/common/status.go +++ b/pkg/controller.v1/common/status.go @@ -2,26 +2,16 @@ package common import ( apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/core" corev1 "k8s.io/api/core/v1" ) // initializeReplicaStatuses initializes the ReplicaStatuses for replica. func initializeReplicaStatuses(jobStatus *apiv1.JobStatus, rtype apiv1.ReplicaType) { - if jobStatus.ReplicaStatuses == nil { - jobStatus.ReplicaStatuses = make(map[apiv1.ReplicaType]*apiv1.ReplicaStatus) - } - - jobStatus.ReplicaStatuses[rtype] = &apiv1.ReplicaStatus{} + core.InitializeReplicaStatuses(jobStatus, rtype) } // updateJobReplicaStatuses updates the JobReplicaStatuses according to the pod. 
func updateJobReplicaStatuses(jobStatus *apiv1.JobStatus, rtype apiv1.ReplicaType, pod *corev1.Pod) { - switch pod.Status.Phase { - case corev1.PodRunning: - jobStatus.ReplicaStatuses[rtype].Active++ - case corev1.PodSucceeded: - jobStatus.ReplicaStatuses[rtype].Succeeded++ - case corev1.PodFailed: - jobStatus.ReplicaStatuses[rtype].Failed++ - } + core.UpdateJobReplicaStatuses(jobStatus, rtype, pod) } diff --git a/pkg/controller.v1/common/util.go b/pkg/controller.v1/common/util.go index f1800210..7bd217d2 100644 --- a/pkg/controller.v1/common/util.go +++ b/pkg/controller.v1/common/util.go @@ -16,10 +16,13 @@ package common import ( "fmt" + "sort" "strings" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "k8s.io/api/scheduling/v1beta1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -99,3 +102,44 @@ func AddResourceList(list, req, limit v1.ResourceList) { } } } + +type PriorityClassGetFunc func(string) (*v1beta1.PriorityClass, error) + +func CalcPGMinResources(minMember int32, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pcGetFunc PriorityClassGetFunc) *v1.ResourceList { + var replicasPriority ReplicasPriority + for t, replica := range replicas { + rp := ReplicaPriority{0, *replica} + pc := replica.Template.Spec.PriorityClassName + + priorityClass, err := pcGetFunc(pc) + if err != nil || priorityClass == nil { + log.Warnf("Ignore task %s priority class %s: %v", t, pc, err) + } else { + rp.priority = priorityClass.Value + } + + replicasPriority = append(replicasPriority, rp) + } + + sort.Sort(replicasPriority) + + minAvailableTasksRes := v1.ResourceList{} + podCnt := int32(0) + for _, task := range replicasPriority { + if task.Replicas == nil { + continue + } + + for i := int32(0); i < *task.Replicas; i++ { + if podCnt >= minMember { + break + } + podCnt++ + for _, c := range task.Template.Spec.Containers { + AddResourceList(minAvailableTasksRes, c.Resources.Requests, c.Resources.Limits) + } + } + } + + return &minAvailableTasksRes +} diff --git a/pkg/core/job.go b/pkg/core/job.go new file mode 100644 index 00000000..f0b67cb4 --- /dev/null +++ b/pkg/core/job.go @@ -0,0 +1,115 @@ +package core + +import ( + "sort" + "time" + + log "github.com/sirupsen/logrus" + + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/tools/record" +) + +// RecordAbnormalPods records the active pod whose latest condition is not in True status. +func RecordAbnormalPods(activePods []*v1.Pod, object runtime.Object, recorder record.EventRecorder) { + for _, pod := range activePods { + // If the pod starts running, should checks the container statuses rather than the conditions. + recordContainerStatus := func(status *v1.ContainerStatus) { + if status.State.Terminated != nil && status.State.Terminated.ExitCode != 0 { + terminated := status.State.Terminated + recorder.Eventf(object, v1.EventTypeWarning, terminated.Reason, + "Error pod %s container %s exitCode: %d terminated message: %s", + pod.Name, status.Name, terminated.ExitCode, terminated.Message) + } + // The terminated state and waiting state don't simultaneously exists, checks them at the same time. 
+ if status.State.Waiting != nil && status.State.Waiting.Message != "" { + wait := status.State.Waiting + recorder.Eventf(object, v1.EventTypeWarning, wait.Reason, + "Error pod %s container %s waiting message: %s", pod.Name, status.Name, wait.Message) + } + } + if len(pod.Status.ContainerStatuses) != 0 { + for _, status := range pod.Status.ContainerStatuses { + recordContainerStatus(&status) + } + // If the pod has container status info, that means the init container statuses are normal. + continue + } + if len(pod.Status.InitContainerStatuses) != 0 { + for _, status := range pod.Status.InitContainerStatuses { + recordContainerStatus(&status) + } + continue + } + if len(pod.Status.Conditions) == 0 { + continue + } + // Should not modify the original pod which is stored in the informer cache. + status := pod.Status.DeepCopy() + sort.Slice(status.Conditions, func(i, j int) bool { + return status.Conditions[i].LastTransitionTime.After(status.Conditions[j].LastTransitionTime.Time) + }) + condition := status.Conditions[0] + if condition.Status == v1.ConditionTrue { + continue + } + recorder.Eventf(object, v1.EventTypeWarning, condition.Reason, "Error pod %s condition message: %s", pod.Name, condition.Message) + } +} + +// PastActiveDeadline checks if job has ActiveDeadlineSeconds field set and if it is exceeded. +func PastActiveDeadline(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus) bool { + if runPolicy.ActiveDeadlineSeconds == nil || jobStatus.StartTime == nil { + return false + } + now := metav1.Now() + start := jobStatus.StartTime.Time + duration := now.Time.Sub(start) + allowedDuration := time.Duration(*runPolicy.ActiveDeadlineSeconds) * time.Second + return duration >= allowedDuration +} + +// PastBackoffLimit checks if container restartCounts sum exceeds BackoffLimit +// this method applies only to pods with restartPolicy == OnFailure or Always +func PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy, + replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pods []*v1.Pod, + podFilterFunc func(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error)) (bool, error) { + if runPolicy.BackoffLimit == nil { + return false, nil + } + result := int32(0) + for rtype, spec := range replicas { + if spec.RestartPolicy != apiv1.RestartPolicyOnFailure && spec.RestartPolicy != apiv1.RestartPolicyAlways { + log.Warnf("The restart policy of replica %v of the job %v is not OnFailure or Always. Not counted in backoff limit.", rtype, jobName) + continue + } + // Convert ReplicaType to lower string. 
+ pods, err := podFilterFunc(pods, rtype) + if err != nil { + return false, err + } + for i := range pods { + po := pods[i] + if po.Status.Phase != v1.PodRunning { + continue + } + for j := range po.Status.InitContainerStatuses { + stat := po.Status.InitContainerStatuses[j] + result += stat.RestartCount + } + for j := range po.Status.ContainerStatuses { + stat := po.Status.ContainerStatuses[j] + result += stat.RestartCount + } + } + } + + if *runPolicy.BackoffLimit == 0 { + return result > 0, nil + } + return result >= *runPolicy.BackoffLimit, nil +} diff --git a/pkg/core/pod.go b/pkg/core/pod.go new file mode 100644 index 00000000..30cdd1e8 --- /dev/null +++ b/pkg/core/pod.go @@ -0,0 +1,77 @@ +package core + +import ( + utillabels "github.com/kubeflow/common/pkg/util/labels" + + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + log "github.com/sirupsen/logrus" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/labels" +) + +// FilterPodsForReplicaType returns pods belong to a replicaType. +func FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { + var result []*v1.Pod + + selector := labels.SelectorFromValidatedSet(labels.Set{ + apiv1.ReplicaTypeLabel: string(replicaType), + }) + + // TODO(#149): Remove deprecated selector. + deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ + apiv1.ReplicaTypeLabelDeprecated: string(replicaType), + }) + + for _, pod := range pods { + set := labels.Set(pod.Labels) + if !selector.Matches(set) && !deprecatedSelector.Matches(set) { + continue + } + result = append(result, pod) + } + return result, nil +} + +// GetPodSlices returns a slice, which element is the slice of pod. +// It gives enough information to caller to make decision to up/down scale resources. 
+func GetPodSlices(pods []*v1.Pod, replicas int, logger *log.Entry) [][]*v1.Pod { + podSlices := make([][]*v1.Pod, CalculatePodSliceSize(pods, replicas)) + for _, pod := range pods { + index, err := utillabels.ReplicaIndex(pod.Labels) + if err != nil { + logger.Warningf("Error obtaining replica index from Pod %s/%s: %v", pod.Namespace, pod.Name, err) + continue + } + if index < 0 || index >= replicas { + logger.Warningf("The label index is not expected: %d, pod: %s/%s", index, pod.Namespace, pod.Name) + } + + podSlices[index] = append(podSlices[index], pod) + } + return podSlices +} + +// CalculatePodSliceSize compare max pod index with desired replicas and return larger size +func CalculatePodSliceSize(pods []*v1.Pod, replicas int) int { + size := 0 + for _, pod := range pods { + index, err := utillabels.ReplicaIndex(pod.Labels) + if err != nil { + continue + } + size = MaxInt(size, index) + } + + // size comes from index, need to +1 to indicate real size + return MaxInt(size+1, replicas) +} + +// SetRestartPolicy check the RestartPolicy defined in job spec and overwrite RestartPolicy in podTemplate if necessary +func SetRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *apiv1.ReplicaSpec) { + // This is necessary since restartPolicyExitCode is not supported in v1.PodTemplateSpec + if spec.RestartPolicy == apiv1.RestartPolicyExitCode { + podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever + } else { + podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicy(spec.RestartPolicy) + } +} diff --git a/pkg/core/service.go b/pkg/core/service.go new file mode 100644 index 00000000..73739b9f --- /dev/null +++ b/pkg/core/service.go @@ -0,0 +1,90 @@ +package core + +import ( + "fmt" + + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + utillabels "github.com/kubeflow/common/pkg/util/labels" + log "github.com/sirupsen/logrus" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/labels" +) + +// FilterServicesForReplicaType returns service belong to a replicaType. +func FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { + var result []*v1.Service + + selector := labels.SelectorFromValidatedSet(labels.Set{ + apiv1.ReplicaTypeLabel: string(replicaType), + }) + + // TODO(#149): Remove deprecated selector. + deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ + apiv1.ReplicaTypeLabelDeprecated: string(replicaType), + }) + + for _, service := range services { + set := labels.Set(service.Labels) + if !selector.Matches(set) && !deprecatedSelector.Matches(set) { + continue + } + result = append(result, service) + } + return result, nil +} + +// GetServiceSlices returns a slice, which element is the slice of service. +// Assume the return object is serviceSlices, then serviceSlices[i] is an +// array of pointers to services corresponding to Services for replica i. 
+func GetServiceSlices(services []*v1.Service, replicas int, logger *log.Entry) [][]*v1.Service { + serviceSlices := make([][]*v1.Service, CalculateServiceSliceSize(services, replicas)) + for _, service := range services { + index, err := utillabels.ReplicaIndex(service.Labels) + if err != nil { + logger.Warningf("Error obtaining index for service %s/%s: %v", service.Namespace, service.Name, err) + continue + } + if index < 0 || index >= replicas { + logger.Warningf("The label index is not expected: %d, service: %s/%s", index, service.Namespace, service.Name) + } + + serviceSlices[index] = append(serviceSlices[index], service) + } + return serviceSlices +} + +// CalculateServiceSliceSize compare max pod index with desired replicas and return larger size +func CalculateServiceSliceSize(services []*v1.Service, replicas int) int { + size := 0 + for _, svc := range services { + index, err := utillabels.ReplicaIndex(svc.Labels) + if err != nil { + continue + } + size = MaxInt(size, index) + } + + // size comes from index, need to +1 to indicate real size + return MaxInt(size+1, replicas) +} + +// GetPortsFromJob gets the ports of job container. Port could be nil, if distributed communication strategy doesn't need and no other ports that need to be exposed. +func GetPortsFromJob(spec *apiv1.ReplicaSpec, defaultContainerName string) (map[string]int32, error) { + ports := make(map[string]int32) + + containers := spec.Template.Spec.Containers + for _, container := range containers { + if container.Name == defaultContainerName { + containerPorts := container.Ports + if len(containerPorts) == 0 { + return nil, nil + } + for _, port := range containerPorts { + ports[port.Name] = port.ContainerPort + } + return ports, nil + } + } + + return nil, fmt.Errorf("failed to find the port") +} diff --git a/pkg/core/status.go b/pkg/core/status.go new file mode 100644 index 00000000..6b941926 --- /dev/null +++ b/pkg/core/status.go @@ -0,0 +1,27 @@ +package core + +import ( + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + corev1 "k8s.io/api/core/v1" +) + +// InitializeReplicaStatuses initializes the ReplicaStatuses for replica. +func InitializeReplicaStatuses(jobStatus *apiv1.JobStatus, rtype apiv1.ReplicaType) { + if jobStatus.ReplicaStatuses == nil { + jobStatus.ReplicaStatuses = make(map[apiv1.ReplicaType]*apiv1.ReplicaStatus) + } + + jobStatus.ReplicaStatuses[rtype] = &apiv1.ReplicaStatus{} +} + +// UpdateJobReplicaStatuses updates the JobReplicaStatuses according to the pod. 
+func UpdateJobReplicaStatuses(jobStatus *apiv1.JobStatus, rtype apiv1.ReplicaType, pod *corev1.Pod) {
+	switch pod.Status.Phase {
+	case corev1.PodRunning:
+		jobStatus.ReplicaStatuses[rtype].Active++
+	case corev1.PodSucceeded:
+		jobStatus.ReplicaStatuses[rtype].Succeeded++
+	case corev1.PodFailed:
+		jobStatus.ReplicaStatuses[rtype].Failed++
+	}
+}
diff --git a/pkg/core/utils.go b/pkg/core/utils.go
new file mode 100644
index 00000000..0f12a3c3
--- /dev/null
+++ b/pkg/core/utils.go
@@ -0,0 +1,19 @@
+package core
+
+import (
+	"strings"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+)
+
+func MaxInt(x, y int) int {
+	if x < y {
+		return y
+	}
+	return x
+}
+
+func GenGeneralName(jobName string, rtype commonv1.ReplicaType, index string) string {
+	n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index
+	return strings.Replace(n, "/", "-", -1)
+}
diff --git a/pkg/reconciler.v1/common/README.md b/pkg/reconciler.v1/common/README.md
new file mode 100644
index 00000000..15a0fe92
--- /dev/null
+++ b/pkg/reconciler.v1/common/README.md
@@ -0,0 +1,24 @@
+## Reconciler.v1
+
+This package provides most of the functionality of `pkg/controller.v1` in the form of a [reconciler](https://book.kubebuilder.io/cronjob-tutorial/controller-overview.html).
+
+To use the reconciler, the following methods must be overridden according to the APIs the reconciler handles.
+
+```go
+// GetJob returns the job that matches the request
+func (r *KubeflowJobReconciler) GetJob(ctx context.Context, req ctrl.Request) (client.Object, error)
+
+// ExtractReplicasSpec extracts the ReplicasSpec map from this job
+func (r *KubeflowJobReconciler) ExtractReplicasSpec(job client.Object) (map[commonv1.ReplicaType]*commonv1.ReplicaSpec, error)
+
+// ExtractRunPolicy extracts the RunPolicy from this job
+func (r *KubeflowJobReconciler) ExtractRunPolicy(job client.Object) (*commonv1.RunPolicy, error)
+
+// ExtractJobStatus extracts the JobStatus from this job
+func (r *KubeflowJobReconciler) ExtractJobStatus(job client.Object) (*commonv1.JobStatus, error)
+
+// IsMasterRole checks if the Pod is the master Pod
+func (r *KubeflowJobReconciler) IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, rtype commonv1.ReplicaType, index int) bool
+```
+
+A simple example can be found at `test_job/reconciler.v1/test_job/test_job_reconciler.go`.
\ No newline at end of file
diff --git a/pkg/reconciler.v1/common/gang.go b/pkg/reconciler.v1/common/gang.go
new file mode 100644
index 00000000..79966da6
--- /dev/null
+++ b/pkg/reconciler.v1/common/gang.go
@@ -0,0 +1,34 @@
+// Copyright 2021 The Kubeflow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package common + +import ( + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// BaseGangReconciler defines a basic gang reconciler +type BaseGangReconciler struct { + Enabled bool +} + +// GangSchedulingEnabled returns if gang-scheduling is enabled for all jobs +func (r *BaseGangReconciler) GangSchedulingEnabled() bool { + return r.Enabled +} + +// GetPodGroupName returns the name of PodGroup for this job +func (r *BaseGangReconciler) GetPodGroupName(job client.Object) string { + return job.GetName() +} diff --git a/pkg/reconciler.v1/common/gang_scheduler_framework.go b/pkg/reconciler.v1/common/gang_scheduler_framework.go new file mode 100644 index 00000000..cafd54c0 --- /dev/null +++ b/pkg/reconciler.v1/common/gang_scheduler_framework.go @@ -0,0 +1,21 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +// SchedulerFrameworkReconciler defines a gang-scheduling reconciler for Kubernetes Scheduler Framework +// TODO(zw0610): implement SchedulerFrameworkReconciler +type SchedulerFrameworkReconciler struct { + BaseGangReconciler +} diff --git a/pkg/reconciler.v1/common/gang_volcano.go b/pkg/reconciler.v1/common/gang_volcano.go new file mode 100644 index 00000000..637c5d2e --- /dev/null +++ b/pkg/reconciler.v1/common/gang_volcano.go @@ -0,0 +1,193 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package common + +import ( + "context" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + controllerv1 "github.com/kubeflow/common/pkg/controller.v1/common" + commonutil "github.com/kubeflow/common/pkg/util" + "github.com/kubeflow/common/pkg/util/k8sutil" + + log "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/api/scheduling/v1beta1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + volcano "volcano.sh/apis/pkg/apis/scheduling/v1beta1" +) + +// VolcanoReconciler defines a gang-scheduling reconciler for volcano.sh/volcano +type VolcanoReconciler struct { + BaseGangReconciler + ReconcilerUtilInterface + client.Client +} + +const ( + // VolcanoPodGroupAnnotation defines which PodGroup is linked to this Pod in annotation + VolcanoPodGroupAnnotation = "scheduling.k8s.io/group-name" +) + +// BareVolcanoReconciler returns a VolcanoReconciler pointer with minimal components defined +func BareVolcanoReconciler(client client.Client, bgReconciler *BaseGangReconciler, enabled bool) *VolcanoReconciler { + if bgReconciler == nil { + bgReconciler = &BaseGangReconciler{} + } + bgReconciler.Enabled = enabled + return &VolcanoReconciler{ + BaseGangReconciler: *bgReconciler, + Client: client, + } +} + +// OverrideForGangSchedulingInterface reset ReconcilerUtilInterface used in this VolcanoReconciler +func (r *VolcanoReconciler) OverrideForGangSchedulingInterface(ui ReconcilerUtilInterface) { + if ui != nil { + r.ReconcilerUtilInterface = ui + } +} + +// GetGangSchedulerName returns the name of Gang Scheduler will be used, which is "volcano" for VolcanoReconciler +func (r *VolcanoReconciler) GetGangSchedulerName() string { + return "volcano" +} + +// GetPodGroupForJob returns the PodGroup associated with this job +func (r *VolcanoReconciler) GetPodGroupForJob(ctx context.Context, job client.Object) (client.Object, error) { + var pg *volcano.PodGroup = nil + err := r.Get(ctx, types.NamespacedName{ + Namespace: job.GetNamespace(), + Name: r.GetPodGroupName(job), + }, pg) + + return pg, err +} + +// DeletePodGroup delete the PodGroup associated with this job +func (r *VolcanoReconciler) DeletePodGroup(ctx context.Context, job client.Object) error { + pg := &volcano.PodGroup{} + pg.SetNamespace(job.GetNamespace()) + pg.SetName(r.GetPodGroupName(job)) + + err := r.Delete(ctx, pg) + if errors.IsNotFound(err) { + return nil + } + return err +} + +// ReconcilePodGroup reconciles the PodGroup resource for this job +func (r *VolcanoReconciler) ReconcilePodGroup( + ctx context.Context, + job client.Object, + runPolicy *commonv1.RunPolicy, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error { + + minMember := k8sutil.GetTotalReplicas(replicas) + queue := "" + priorityClass := "" + var minResources *corev1.ResourceList + + if runPolicy.SchedulingPolicy != nil { + if runPolicy.SchedulingPolicy.MinAvailable != nil { + minMember = *runPolicy.SchedulingPolicy.MinAvailable + } + + if runPolicy.SchedulingPolicy.Queue != "" { + queue = runPolicy.SchedulingPolicy.Queue + } + + if runPolicy.SchedulingPolicy.PriorityClass != "" { + priorityClass = runPolicy.SchedulingPolicy.PriorityClass + } + + if runPolicy.SchedulingPolicy.MinResources != nil { + minResources = runPolicy.SchedulingPolicy.MinResources + } + } + + if minResources == nil { + minResources = r.calcPGMinResources(minMember, replicas) + } + 
+ pgSpec := volcano.PodGroupSpec{ + MinMember: minMember, + Queue: queue, + PriorityClassName: priorityClass, + MinResources: minResources, + } + + // Check if exist + pg := &volcano.PodGroup{} + err := r.Get(ctx, types.NamespacedName{Namespace: job.GetNamespace(), Name: r.GetPodGroupName(job)}, pg) + // If Created, check updates, otherwise create it + if err == nil { + pg.Spec = pgSpec + err = r.Update(ctx, pg) + } + + if errors.IsNotFound(err) { + pg.ObjectMeta = metav1.ObjectMeta{ + Name: r.GetPodGroupName(job), + Namespace: job.GetNamespace(), + } + pg.Spec = pgSpec + err = controllerutil.SetControllerReference(job, pg, r.GetScheme()) + if err == nil { + err = r.Create(ctx, pg) + } + } + + if err != nil { + log.Warnf("Sync PodGroup %v: %v", + types.NamespacedName{Namespace: job.GetNamespace(), Name: r.GetPodGroupName(job)}, err) + return err + } + + return nil +} + +// DecoratePodForGangScheduling decorates the podTemplate before it's used to generate a pod with information for gang-scheduling +func (r *VolcanoReconciler) DecoratePodForGangScheduling(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) { + if podTemplate.Spec.SchedulerName == "" || podTemplate.Spec.SchedulerName == r.GetGangSchedulerName() { + podTemplate.Spec.SchedulerName = r.GetGangSchedulerName() + } else { + warnMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten" + commonutil.LoggerForReplica(job, rtype).Warn(warnMsg) + r.GetRecorder().Event(job, corev1.EventTypeWarning, "PodTemplateSchedulerNameAlreadySet", warnMsg) + } + + if podTemplate.Annotations == nil { + podTemplate.Annotations = map[string]string{} + } + + podTemplate.Annotations[VolcanoPodGroupAnnotation] = job.GetName() +} + +// calcPGMinResources calculates the minimal resources needed for this job. The value will be embedded into the associated PodGroup +func (r *VolcanoReconciler) calcPGMinResources(minMember int32, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) *corev1.ResourceList { + pcGetFunc := func(pc string) (*v1beta1.PriorityClass, error) { + priorityClass := &v1beta1.PriorityClass{} + err := r.Get(context.Background(), types.NamespacedName{Name: pc}, priorityClass) + return priorityClass, err + } + + return controllerv1.CalcPGMinResources(minMember, replicas, pcGetFunc) +} diff --git a/pkg/reconciler.v1/common/interface.go b/pkg/reconciler.v1/common/interface.go new file mode 100644 index 00000000..3ef70337 --- /dev/null +++ b/pkg/reconciler.v1/common/interface.go @@ -0,0 +1,260 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package common + +import ( + "context" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + + "github.com/go-logr/logr" + "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// ReconcilerUtilInterface defines the abstract interface of reconciler on utility features, such like get event +// recorder or logger +type ReconcilerUtilInterface interface { + // GetReconcilerName SHOULD be overridden if a new Reconciler is defined. The default implementation returns + // "Kubeflow Reconciler" + GetReconcilerName() string + + // GetRecorder CAN be overridden to customize EventRecorder + GetRecorder() record.EventRecorder + + // GetLogger CAN be overridden to customize logger + GetLogger(job client.Object) logr.Logger + + // GetScheme CAN be overridden to customize runtime scheme + GetScheme() *runtime.Scheme +} + +// GangSchedulingInterface defines the abstract interface for gang-scheduling related actions, such like get, create or +// delete PodGroup +type GangSchedulingInterface interface { + // OverrideForGangSchedulingInterface MUST NOT be overridden as it reset ReconcilerUtilInterface + OverrideForGangSchedulingInterface(ui ReconcilerUtilInterface) + + // GangSchedulingEnabled CAN be overridden if definition of gang-scheduling enabling changes. + GangSchedulingEnabled() bool + + // GetGangSchedulerName CAN be overridden to customize the name of gang scheduler. This name will be used to check + // the value of podTemplateSpec.Spec.SchedulerName. For volcano, it is "volcano". + GetGangSchedulerName() string + + // GetPodGroupName CAN be overridden to customize the name of PodGroup generated for the job. For example: + // podGroupName := fmt.Sprintf("%s-podgroup", job.GetName()) or podGroupName := job.GetName() + GetPodGroupName(job client.Object) string + + // GetPodGroupForJob SHOULD be overridden if Group, APIVersion or Kind changes for PodGroup. The PodGroup is + // defined in different gang-scheduler as: + // Kube-Batch: "scheduling.incubator.k8s.io/v1alpha1/PodGroup", "scheduling.sigs.dev/v1alpha2/PodGroup" + // Volcano: "scheduling.volcano.sh/v1beta1/PodGroup" + // Scheduler-Framework: "scheduling.sigs.k8s.io/v1alpha1/PodGroup" + GetPodGroupForJob(ctx context.Context, job client.Object) (client.Object, error) + + // DeletePodGroup SHOULD be overridden if Group, APIVersion and Kind changes for PodGroup. + DeletePodGroup(ctx context.Context, job client.Object) error + + // ReconcilePodGroup CAN be overridden if the logic to reconcile PodGroup changes. + ReconcilePodGroup(ctx context.Context, job client.Object, runPolicy *commonv1.RunPolicy, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error + + // DecoratePodForGangScheduling SHOULD be overridden if gang scheduler demands Pods associated with PodGroup to be + // decorated with specific requests. 
+	DecoratePodForGangScheduling(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object)
+}
+
+// PodInterface defines the abstract interface for Pod related actions, such as getting, creating or deleting Pods
+type PodInterface interface {
+	// OverrideForPodInterface MUST NOT be overridden as it resets ReconcilerUtilInterface, GangSchedulingInterface and JobInterface
+	OverrideForPodInterface(ui ReconcilerUtilInterface, gi GangSchedulingInterface, ji JobInterface)
+
+	// GetDefaultContainerName CAN be overridden if the default container name is not "kubeflow".
+	GetDefaultContainerName() string
+
+	// GenPodName CAN be overridden to customize Pod name.
+	GenPodName(jobName string, rtype commonv1.ReplicaType, index string) string
+
+	// GetPodsForJob CAN be overridden to customize how to list all pods associated with the job.
+	GetPodsForJob(ctx context.Context, job client.Object) ([]*corev1.Pod, error)
+
+	// FilterPodsForReplicaType CAN be overridden if the linking approach between pods and replicaType changes as this
+	// function filters out pods for a specific replica type from all pods associated with the job.
+	FilterPodsForReplicaType(pods []*corev1.Pod, replicaType commonv1.ReplicaType) ([]*corev1.Pod, error)
+
+	// GetPodSlices SHOULD NOT be overridden as it generates pod slices for further pod processing.
+	GetPodSlices(pods []*corev1.Pod, replicas int, logger *logrus.Entry) [][]*corev1.Pod
+
+	// ReconcilePods CAN be overridden if the logic to reconcile all Pods for the job changes.
+	ReconcilePods(
+		ctx context.Context,
+		job client.Object,
+		jobStatus *commonv1.JobStatus,
+		pods []*corev1.Pod,
+		rtype commonv1.ReplicaType,
+		spec *commonv1.ReplicaSpec,
+		replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error
+
+	// CreateNewPod CAN be overridden to customize how to create a new pod.
+	CreateNewPod(job client.Object, rt commonv1.ReplicaType, index string,
+		spec *commonv1.ReplicaSpec, masterRole bool, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error
+
+	// DeletePod CAN be overridden to customize how to delete a pod of {name} in namespace {ns}.
+	DeletePod(ctx context.Context, ns string, name string) error
+
+	// DecoratePod CAN be overridden if customization to the pod is needed. The default implementation applies nothing
+	// to the pod.
+	DecoratePod(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object)
+}
+
+// ServiceInterface defines the abstract interface for Service related actions, such as getting, creating or deleting Services
+type ServiceInterface interface {
+	// OverrideForServiceInterface MUST NOT be overridden as it resets ReconcilerUtilInterface, PodInterface and JobInterface
+	OverrideForServiceInterface(ui ReconcilerUtilInterface, pi PodInterface, ji JobInterface)
+
+	// GetPortsFromJob CAN be overridden to customize how to find ports defined in the ReplicasSpec.
+	GetPortsFromJob(spec *commonv1.ReplicaSpec) (map[string]int32, error)
+
+	// GetServicesForJob CAN be overridden to customize how to find all services associated with this job.
+	GetServicesForJob(ctx context.Context, job client.Object) ([]*corev1.Service, error)
+
+	// FilterServicesForReplicaType CAN be overridden to customize how to filter out services for this Replica Type.
+	FilterServicesForReplicaType(services []*corev1.Service,
+		replicaType commonv1.ReplicaType) ([]*corev1.Service, error)
+
+	// GetServiceSlices CAN be overridden to customize how to generate service slices.
+	GetServiceSlices(services []*corev1.Service, replicas int, logger *logrus.Entry) [][]*corev1.Service
+
+	// ReconcileServices CAN be overridden to customize how to reconcile services for this job.
+	ReconcileServices(
+		job client.Object,
+		services []*corev1.Service,
+		rtype commonv1.ReplicaType,
+		spec *commonv1.ReplicaSpec) error
+
+	// CreateNewService CAN be overridden to customize how to create a new service.
+	CreateNewService(job client.Object, rtype commonv1.ReplicaType,
+		spec *commonv1.ReplicaSpec, index string) error
+
+	// DeleteService CAN be overridden to customize how to delete the service of {name} in namespace {ns}.
+	DeleteService(ns string, name string, job client.Object) error
+
+	// DecorateService CAN be overridden to customize this service right before being created
+	DecorateService(rtype commonv1.ReplicaType, svc *corev1.Service, job client.Object)
+}
+
+// JobInterface defines the abstract interface for Job related actions, such as getting, creating or deleting a TFJob,
+// PyTorchJob or KFJob, etc.
+type JobInterface interface {
+	// OverrideForJobInterface MUST NOT be overridden as it resets ReconcilerUtilInterface, PodInterface, ServiceInterface and GangSchedulingInterface
+	OverrideForJobInterface(ui ReconcilerUtilInterface, pi PodInterface, si ServiceInterface, gi GangSchedulingInterface)
+
+	// GenLabels CAN be overridden to customize the generic labels generated for Pods and Services
+	GenLabels(jobName string) map[string]string
+
+	// GetGroupNameLabelValue CAN be overridden to customize the value used in labels regarding the Group of the job processed.
+	GetGroupNameLabelValue() string
+
+	// GetJob MUST be overridden to get the job of the specified kind
+	GetJob(ctx context.Context, req ctrl.Request) (client.Object, error)
+
+	// ExtractReplicasSpec MUST be overridden to extract ReplicasSpec from a job
+	ExtractReplicasSpec(job client.Object) (map[commonv1.ReplicaType]*commonv1.ReplicaSpec, error)
+
+	// ExtractRunPolicy MUST be overridden to extract the pointer of RunPolicy from a job
+	ExtractRunPolicy(job client.Object) (*commonv1.RunPolicy, error)
+
+	// ExtractJobStatus MUST be overridden to extract the pointer of JobStatus from a job
+	ExtractJobStatus(job client.Object) (*commonv1.JobStatus, error)
+
+	// IsMasterRole MUST be overridden to determine whether the replica of this ReplicaType with the specified index is the master role.
+	// The master role pod will have "job-role=master" set in its labels
+	IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, rtype commonv1.ReplicaType, index int) bool
+
+	// ReconcileJob CAN be overridden to customize how to reconcile a job.
+	ReconcileJob(
+		ctx context.Context,
+		job client.Object,
+		replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec,
+		status *commonv1.JobStatus,
+		runPolicy *commonv1.RunPolicy) error
+
+	// DeleteJob CAN be overridden to customize how to delete a job.
+	DeleteJob(job client.Object) error
+
+	// UpdateJobStatus CAN be overridden to customize how to update job status without submitting it to the APIServer.
+	UpdateJobStatus(
+		job client.Object,
+		replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec,
+		jobStatus *commonv1.JobStatus) error
+
+	// UpdateJobStatusInAPIServer CAN be overridden to customize how to update job status directly in the APIServer.
+	UpdateJobStatusInAPIServer(ctx context.Context, job client.Object) error
+
+	// CleanupResources CAN be overridden to customize how to delete all resources associated with this job.
+ CleanupResources(runPolicy *commonv1.RunPolicy, status commonv1.JobStatus, job client.Object) error + + // CleanupJob CAN be overridden to customize how to clean up this job. + CleanupJob(runPolicy *commonv1.RunPolicy, status commonv1.JobStatus, job client.Object) error + + // RecordAbnormalPods CAN be overridden to customize how to record abnormal pods + RecordAbnormalPods(activePods []*corev1.Pod, object client.Object) + + // SetStatusForSuccessJob CAN be overridden to customize how to set status for success job + SetStatusForSuccessJob(status *commonv1.JobStatus) + + // IsFlagReplicaTypeForJobStatus CAN be overridden to customize how to determine if this ReplicaType is the + // flag ReplicaType for the status of this kind of job + IsFlagReplicaTypeForJobStatus(rtype commonv1.ReplicaType) bool + + // IsJobSucceeded CAN be overridden to customize how to determine if this job is succeeded. + IsJobSucceeded(status commonv1.JobStatus) bool + + // IsJobFailed CAN be overridden to customize how to determine if this job is failed. + IsJobFailed(status commonv1.JobStatus) bool + + // ShouldCleanUp CAN be overridden to customize how to determine if this job should be cleaned up. + ShouldCleanUp(status commonv1.JobStatus) bool + + // PastBackoffLimit CAN be overridden to customize how to determine if this job has past backoff limit. + PastBackoffLimit(jobName string, runPolicy *commonv1.RunPolicy, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, pods []*corev1.Pod) (bool, error) + + // PastActiveDeadline CAN be overridden to customize how to determine if this job has past activate deadline. + PastActiveDeadline(runPolicy *commonv1.RunPolicy, jobStatus *commonv1.JobStatus) bool +} + +// KubeflowReconcilerInterface defines the abstract interface for a base reconciler for kubeflow jobs. +type KubeflowReconcilerInterface interface { + JobInterface + PodInterface + ServiceInterface + GangSchedulingInterface + ReconcilerUtilInterface + + // OverrideForKubeflowReconcilerInterface MUST NOT be overridden as it reset ReconcilerUtilInterface, PodInterface, ServiceInterface, JobInterface, GangSchedulingInterface + OverrideForKubeflowReconcilerInterface(ji JobInterface, pi PodInterface, si ServiceInterface, gi GangSchedulingInterface, ui ReconcilerUtilInterface) + + // Reconcile CAN be overridden to customize how to handle a request. + Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) + + // SetupWithManager CAN be overridden to customize how to set up the reconciler with the manager. + SetupWithManager(mgr ctrl.Manager, obj client.Object) error +} diff --git a/pkg/reconciler.v1/common/job.go b/pkg/reconciler.v1/common/job.go new file mode 100644 index 00000000..9acecf1e --- /dev/null +++ b/pkg/reconciler.v1/common/job.go @@ -0,0 +1,478 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package common + +import ( + "context" + "fmt" + "reflect" + "strings" + "time" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/core" + commonutil "github.com/kubeflow/common/pkg/util" + "github.com/kubeflow/common/pkg/util/k8sutil" + + "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + GroupName = "kubeflow.org" + + ReasonKey = "reason" + ReasonJobDeleted = "job deleted" + + MsgReconcileCancelled = "Reconcile Cancelled" + MsgReconcileStart = "Reconcile Starts" + + MsgGetPodsFailed = "Get Pods Failed" + MsgGetServicesFailed = "Get Services Failed" + + MsgBackoffLimitReachedTemplate = "Job %s has failed because it has reached the specified backoff limit" + MsgActiveDeadlineReachedTemplate = "Job %s has failed because it was active longer than specified deadline" + + ErrUpdateJobConditionsFailed = "failed to update job conditions" + + ErrUpdateJobErrorTemplate = "UpdateJobStatus error %v" + ErrAppendJobConditionTemplate = "Append job condition error %v" + ErrReconcilePodsTemplate = "ReconcilePods error %v" + ErrReconcileServicesTemplate = "ReconcileServices error %v" + ErrReconcileGangTemplate = "ReconcilePodGroups error %v" + ErrGetReplicasStatusFromStatusFailedTemplate = "failed to get ReplicasStatus for %s from status" + + WarnDefaultImplementationTemplate = "Warning: executing default implementation for KubeflowReconciler.%s" + WarnNotCountedInBackoffLimit = "The restart policy of replica %v of the job %v is not OnFailure or Always. Not counted in backoff limit." +) + +// KubeflowJobReconciler defines a Reconciler dealing with KubeflowJob +type KubeflowJobReconciler struct { + client.Client + ReconcilerUtilInterface + PodInterface + ServiceInterface + GangSchedulingInterface + counter *commonutil.Counter +} + +// BareKubeflowJobReconciler returns the pointer of a KubeflowJobReconciler with minimal implementation +func BareKubeflowJobReconciler(client client.Client) *KubeflowJobReconciler { + return &KubeflowJobReconciler{ + Client: client, + counter: commonutil.NewCounter(), + } +} + +// OverrideForJobInterface resets ReconcilerUtilInterface, PodInterface, ServiceInterface, GangSchedulingInterface used in KubeflowJobReconciler +func (r *KubeflowJobReconciler) OverrideForJobInterface(ui ReconcilerUtilInterface, pi PodInterface, si ServiceInterface, gi GangSchedulingInterface) { + if ui != nil { + r.ReconcilerUtilInterface = ui + } + if pi != nil { + r.PodInterface = pi + } + if si != nil { + r.ServiceInterface = si + } + if gi != nil { + r.GangSchedulingInterface = gi + } +} + +// GenLabels returns labels used for this job (based on the name of this KubeflowJob) +func (r *KubeflowJobReconciler) GenLabels(jobName string) map[string]string { + jobName = strings.Replace(jobName, "/", "-", -1) + return map[string]string{ + // TODO(#149): Remove deprecated labels. 
+ commonv1.OperatorNameLabel: r.GetReconcilerName(), + commonv1.GroupNameLabelDeprecated: r.GetGroupNameLabelValue(), + commonv1.JobNameLabel: jobName, + commonv1.JobNameLabelDeprecated: jobName, + } +} + +// GetGroupNameLabelValue returns the Group Name for the KubeflowJob, which is "kubeflow.org" +func (r *KubeflowJobReconciler) GetGroupNameLabelValue() string { + return GroupName +} + +// ReconcileJob reconciles KubeflowJob +func (r *KubeflowJobReconciler) ReconcileJob( + ctx context.Context, + job client.Object, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, + status *commonv1.JobStatus, + runPolicy *commonv1.RunPolicy) error { + + logger := r.GetLogger(job) + logger.Info(MsgReconcileStart) + + oldStatus := status.DeepCopy() + + var err error = nil + if r.ShouldCleanUp(*status) { + if err = r.CleanupResources(runPolicy, *status, job); err != nil { + return err + } + if err = r.CleanupJob(runPolicy, *status, job); err != nil { + return err + } + if r.IsJobSucceeded(*status) { + r.SetStatusForSuccessJob(status) + } + if !reflect.DeepEqual(*oldStatus, *status) { + return r.UpdateJobStatusInAPIServer(ctx, job) + } + return nil + } + + pods, err := r.GetPodsForJob(ctx, job) + if err != nil { + logger.Info(MsgGetPodsFailed) + return err + } + + services, err := r.GetServicesForJob(ctx, job) + if err != nil { + logger.Info(MsgGetServicesFailed) + return err + } + + previousRetry, _ := r.counter.Counts(types.NamespacedName{ + Namespace: job.GetNamespace(), + Name: job.GetName(), + }.String()) + if previousRetry < 0 { + // TODO: may be we should abort here? + previousRetry = 0 + } + + activePods := k8sutil.FilterActivePods(pods) + r.RecordAbnormalPods(activePods, job) + + active := int32(len(activePods)) + failed := k8sutil.FilterPodCount(pods, corev1.PodFailed) + totalReplicas := k8sutil.GetTotalReplicas(replicas) + prevReplicasFailedNum := k8sutil.GetTotalFailedReplicas(status.ReplicaStatuses) + + var failureMessage string + jobExceedsLimit := false + exceedsBackoffLimit := false + pastBackoffLimit := false + + if runPolicy.BackoffLimit != nil { + jobHasNewFailure := failed > prevReplicasFailedNum + exceedsBackoffLimit = jobHasNewFailure && (active != totalReplicas) && + (int32(previousRetry)+1 > *runPolicy.BackoffLimit) + + pastBackoffLimit, err = r.PastBackoffLimit(job.GetName(), runPolicy, replicas, pods) + if err != nil { + return err + } + } + + if exceedsBackoffLimit || pastBackoffLimit { + // check if the number of pod restart exceeds backoff (for restart OnFailure only) + // OR if the number of failed jobs increased since the last syncJob + jobExceedsLimit = true + failureMessage = fmt.Sprintf(MsgBackoffLimitReachedTemplate, job.GetName()) + } else if r.PastActiveDeadline(runPolicy, status) { + failureMessage = fmt.Sprintf(MsgActiveDeadlineReachedTemplate, job.GetName()) + jobExceedsLimit = true + } + + if jobExceedsLimit { + if status.CompletionTime == nil { + now := metav1.Now() + status.CompletionTime = &now + } + if err = r.CleanupResources(runPolicy, *status, job); err != nil { + return err + } + if err = r.CleanupJob(runPolicy, *status, job); err != nil { + return err + } + if r.IsJobSucceeded(*status) { + r.SetStatusForSuccessJob(status) + } + + r.GetRecorder().Event(job, corev1.EventTypeNormal, commonutil.JobFailedReason, failureMessage) + + if err = commonutil.UpdateJobConditions(status, commonv1.JobFailed, commonutil.JobFailedReason, failureMessage); err != nil { + logrus.Infof(ErrAppendJobConditionTemplate, err) + return err + } + + return 
r.UpdateJobStatusInAPIServer(ctx, job) + } + + if r.GangSchedulingEnabled() { + err = r.ReconcilePodGroup(ctx, job, runPolicy, replicas) + if err != nil { + logrus.Warnf(ErrReconcileGangTemplate, err) + return err + } + } + + for rtype, spec := range replicas { + core.InitializeReplicaStatuses(status, rtype) + + err = r.ReconcilePods(ctx, job, status, pods, rtype, spec, replicas) + if err != nil { + logrus.Warnf(ErrReconcilePodsTemplate, err) + return err + } + + err = r.ReconcileServices(job, services, rtype, spec) + if err != nil { + logrus.Warnf(ErrReconcileServicesTemplate, err) + return err + } + } + + err = r.UpdateJobStatus(job, replicas, status) + if err != nil { + logrus.Warnf(ErrUpdateJobErrorTemplate, err) + return err + } + + if !reflect.DeepEqual(*oldStatus, status) { + return r.UpdateJobStatusInAPIServer(ctx, job) + } + + return nil +} + +// DeleteJob deletes this KubeflowJob +func (r *KubeflowJobReconciler) DeleteJob(job client.Object) error { + return r.Delete(context.Background(), job) +} + +// RecordAbnormalPods records abnormal pods during the reconciliation of jobs +func (r *KubeflowJobReconciler) RecordAbnormalPods(activePods []*corev1.Pod, object client.Object) { + core.RecordAbnormalPods(activePods, object, r.GetRecorder()) +} + +// SetStatusForSuccessJob sets the status for job that succeed +func (r *KubeflowJobReconciler) SetStatusForSuccessJob(status *commonv1.JobStatus) { + for rytpe := range status.ReplicaStatuses { + status.ReplicaStatuses[rytpe].Succeeded += status.ReplicaStatuses[rytpe].Active + status.ReplicaStatuses[rytpe].Active = 0 + } +} + +// UpdateJobStatus updates the status of this KubeflowJob WITHOUT pushing the updated status to the APIServer +func (r *KubeflowJobReconciler) UpdateJobStatus( + job client.Object, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, + jobStatus *commonv1.JobStatus) error { + + logrus.Warnf(WarnDefaultImplementationTemplate, "UpdateJobStatus") + + jobKind := job.GetObjectKind().GroupVersionKind().Kind + jobNamespacedName := types.NamespacedName{Namespace: job.GetNamespace(), Name: job.GetName()}.String() + + logger := r.GetLogger(job) + + for rtype, spec := range replicas { + status, ok := jobStatus.ReplicaStatuses[rtype] + if !ok { + return fmt.Errorf(ErrGetReplicasStatusFromStatusFailedTemplate, rtype) + } + + succeeded := status.Succeeded + expected := *(spec.Replicas) - succeeded + running := status.Active + failed := status.Failed + + logrus.Infof("%s=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d , failed=%d", + jobKind, jobNamespacedName, rtype, expected, running, succeeded, failed) + + if r.IsFlagReplicaTypeForJobStatus(rtype) { + if running > 0 { + msg := fmt.Sprintf("%s %s is running.", jobKind, jobNamespacedName) + err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, commonutil.JobRunningReason, msg) + if err != nil { + logger.Info(ErrAppendJobConditionTemplate, err) + return err + } + } + + if expected == 0 { + msg := fmt.Sprintf("%s %s is successfully completed.", jobKind, jobNamespacedName) + logrus.Info(msg) + r.GetRecorder().Event(job, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg) + if jobStatus.CompletionTime == nil { + now := metav1.Now() + jobStatus.CompletionTime = &now + } + err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobSucceeded, commonutil.JobSucceededReason, msg) + if err != nil { + logger.Info(ErrAppendJobConditionTemplate, err) + } + return nil + } + } + + if failed > 0 { + if spec.RestartPolicy == commonv1.RestartPolicyExitCode { 
+ msg := fmt.Sprintf("%s %s is restarting because %d %s replica(s) failed.", + jobKind, jobNamespacedName, failed, rtype) + r.GetRecorder().Event(job, corev1.EventTypeWarning, commonutil.JobRestartingReason, msg) + err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRestarting, commonutil.JobRestartingReason, msg) + if err != nil { + logger.Info(ErrAppendJobConditionTemplate, err) + return err + } + } else { + msg := fmt.Sprintf("%s %s is failed because %d %s replica(s) failed.", + jobKind, jobNamespacedName, failed, rtype) + if jobStatus.CompletionTime == nil { + now := metav1.Now() + jobStatus.CompletionTime = &now + } + err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobFailed, commonutil.JobFailedReason, msg) + if err != nil { + logger.Info(ErrAppendJobConditionTemplate, err) + return err + } + } + } + + } + + msg := fmt.Sprintf("%s %s is running.", jobKind, jobNamespacedName) + logger.Info(msg) + + if err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, commonutil.JobRunningReason, msg); err != nil { + logger.Error(err, ErrUpdateJobConditionsFailed, jobKind) + return err + } + + return nil +} + +// UpdateJobStatusInAPIServer updates the status of this KubeflowJob in APIServer +func (r *KubeflowJobReconciler) UpdateJobStatusInAPIServer(ctx context.Context, job client.Object) error { + return r.Status().Update(ctx, job) +} + +// CleanupResources cleans up all resources associated with this KubeflowJob +func (r *KubeflowJobReconciler) CleanupResources(runPolicy *commonv1.RunPolicy, status commonv1.JobStatus, job client.Object) error { + if *runPolicy.CleanPodPolicy == commonv1.CleanPodPolicyNone { + return nil + } + ctx := context.Background() + cleanRunningPod := *runPolicy.CleanPodPolicy == commonv1.CleanPodPolicyRunning + + if err := r.DeletePodGroup(ctx, job); err != nil { + return err + } + + pods, err := r.GetPodsForJob(ctx, job) + if err != nil { + return err + } + + for _, pod := range pods { + if cleanRunningPod && pod.Status.Phase != corev1.PodRunning && pod.Status.Phase != corev1.PodPending { + continue + } + if err = r.Delete(ctx, pod); err != nil { + return err + } + // Each Pod may or may not has its service with same name + svc := &corev1.Service{} + err = r.Get(ctx, types.NamespacedName{Namespace: pod.Namespace, Name: pod.Name}, svc) + if errors.IsNotFound(err) { + continue + } + if err != nil { + return err + } + if err = r.Delete(ctx, svc); err != nil { + return err + } + + } + + return nil +} + +// CleanupJob cleans up all resources associated with this KubeflowJob as well as the job itself +func (r *KubeflowJobReconciler) CleanupJob(runPolicy *commonv1.RunPolicy, status commonv1.JobStatus, job client.Object) error { + currentTime := time.Now() + + ttl := runPolicy.TTLSecondsAfterFinished + if ttl == nil { + return nil + } + + duration := time.Second * time.Duration(*ttl) + // todo: Is the jobStatus.CompletionTime maybe nil ? + finishTime := status.CompletionTime + expireTime := finishTime.Add(duration) + + if currentTime.After(expireTime) { + err := r.DeleteJob(job) + if err != nil { + commonutil.LoggerForJob(job).Warnf("Cleanup Job error: %v.", err) + return err + } + return nil + } else { + if finishTime.After(currentTime) { + commonutil.LoggerForJob(job).Warnf("Found Job finished in the future. This is likely due to time skew in the cluster. 
Job cleanup will be deferred.") + } + } + return nil +} + +// IsFlagReplicaTypeForJobStatus checks if this replicaType is the flag replicaType for the status of KubeflowJob +func (r *KubeflowJobReconciler) IsFlagReplicaTypeForJobStatus(rtype commonv1.ReplicaType) bool { + logrus.Warnf(WarnDefaultImplementationTemplate, "IsFlagReplicaTypeForJobStatus") + return true +} + +// IsJobSucceeded checks if this KubeflowJob succeeded +func (r *KubeflowJobReconciler) IsJobSucceeded(status commonv1.JobStatus) bool { + return commonutil.IsSucceeded(status) +} + +// IsJobFailed checks if this KubeflowJob failed +func (r *KubeflowJobReconciler) IsJobFailed(status commonv1.JobStatus) bool { + return commonutil.IsFailed(status) +} + +// ShouldCleanUp checks if resources associated with this KubeflowJob should be cleaned up +func (r *KubeflowJobReconciler) ShouldCleanUp(status commonv1.JobStatus) bool { + return r.IsJobSucceeded(status) || r.IsJobFailed(status) +} + +// PastBackoffLimit checks if this KubeflowJob has past backoff limit +func (r *KubeflowJobReconciler) PastBackoffLimit(jobName string, runPolicy *commonv1.RunPolicy, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, pods []*corev1.Pod) (bool, error) { + return core.PastBackoffLimit(jobName, runPolicy, replicas, pods, r.FilterPodsForReplicaType) +} + +// PastActiveDeadline checks if this KubeflowJob has ActiveDeadlineSeconds field set and if it is exceeded. +func (r *KubeflowJobReconciler) PastActiveDeadline(runPolicy *commonv1.RunPolicy, jobStatus *commonv1.JobStatus) bool { + return core.PastActiveDeadline(runPolicy, *jobStatus) +} diff --git a/pkg/reconciler.v1/common/pod.go b/pkg/reconciler.v1/common/pod.go new file mode 100644 index 00000000..f6f69cab --- /dev/null +++ b/pkg/reconciler.v1/common/pod.go @@ -0,0 +1,276 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
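The TTL path in CleanupJob above dereferences status.CompletionTime directly (the in-code TODO flags this); a defensive variant would treat a missing completion time as "not expired yet". The sketch below is illustrative only and is not part of this patch:

package common

import (
	"time"

	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
)

// ttlExpired mirrors the CleanupJob decision but guards against a nil
// CompletionTime: no TTL configured or job not yet finished means "do not delete".
func ttlExpired(runPolicy *commonv1.RunPolicy, status commonv1.JobStatus, now time.Time) bool {
	ttl := runPolicy.TTLSecondsAfterFinished
	if ttl == nil || status.CompletionTime == nil {
		return false
	}
	expireTime := status.CompletionTime.Add(time.Duration(*ttl) * time.Second)
	return now.After(expireTime)
}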
+ +package common + +import ( + "context" + "strconv" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + log "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/core" + commonutil "github.com/kubeflow/common/pkg/util" + trainutil "github.com/kubeflow/common/pkg/util/train" +) + +// DefaultContainerName defines the default name for container in Pod +const DefaultContainerName = "kubeflow" + +var ( + // Prometheus metrics + createdPodsCount = promauto.NewCounter(prometheus.CounterOpts{ + Name: "reconciler_created_pods_total", + Help: "The total number of created pods", + }) + deletedPodsCount = promauto.NewCounter(prometheus.CounterOpts{ + Name: "reconciler_deleted_pods_total", + Help: "The total number of deleted pods", + }) + failedPodsCount = promauto.NewCounter(prometheus.CounterOpts{ + Name: "reconciler_failed_pods_total", + Help: "The total number of failed pods", + }) +) + +// KubeflowPodReconciler defines a Pod Reconciler for KubeflowJob +type KubeflowPodReconciler struct { + client.Client + ReconcilerUtilInterface + GangSchedulingInterface + JobInterface +} + +// OverrideForPodInterface resets ReconcilerUtilInterface, GangSchedulingInterface, JobInterface for KubeflowPodReconciler +func (r *KubeflowPodReconciler) OverrideForPodInterface(ui ReconcilerUtilInterface, gi GangSchedulingInterface, ji JobInterface) { + if ui != nil { + r.ReconcilerUtilInterface = ui + } + if ji != nil { + r.JobInterface = ji + } + if gi != nil { + r.GangSchedulingInterface = gi + } +} + +// BareKubeflowPodReconciler returns a pointer of BareKubeflowPodReconciler with minimal implementation +func BareKubeflowPodReconciler(client client.Client) *KubeflowPodReconciler { + return &KubeflowPodReconciler{Client: client} +} + +// GenPodName returns the name of the Pod based on jobName, replicaType and its index +func (r *KubeflowPodReconciler) GenPodName(jobName string, rtype commonv1.ReplicaType, index string) string { + return core.GenGeneralName(jobName, rtype, index) +} + +// GetDefaultContainerName returns the default name of the container +func (r *KubeflowPodReconciler) GetDefaultContainerName() string { + return DefaultContainerName +} + +// GetPodsForJob returns all Pods associated with this job +func (r *KubeflowPodReconciler) GetPodsForJob(ctx context.Context, job client.Object) ([]*corev1.Pod, error) { + podList := &corev1.PodList{} + err := r.List(ctx, podList, client.MatchingLabels(r.GenLabels(job.GetName()))) + if err != nil { + return nil, err + } + + var pods []*corev1.Pod = nil + for _, pod := range podList.Items { + pods = append(pods, &pod) + } + + return pods, nil + // TODO: (zw0610) adding Claiming Pods +} + +// GetPodSlices generates podSlice from all Pods listed for this job +func (r *KubeflowPodReconciler) GetPodSlices(pods []*corev1.Pod, replicas int, logger *log.Entry) [][]*corev1.Pod { + return core.GetPodSlices(pods, replicas, logger) +} + +// FilterPodsForReplicaType filters out Pods for this replicaType +func (r *KubeflowPodReconciler) FilterPodsForReplicaType(pods []*corev1.Pod, replicaType commonv1.ReplicaType) ([]*corev1.Pod, error) { + return core.FilterPodsForReplicaType(pods, replicaType) +} + +// ReconcilePods reconciles Pods for this job +func (r 
*KubeflowPodReconciler) ReconcilePods( + ctx context.Context, + job client.Object, + jobStatus *commonv1.JobStatus, + pods []*corev1.Pod, + rtype commonv1.ReplicaType, + spec *commonv1.ReplicaSpec, + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error { + + // Convert ReplicaType to lower string. + logger := commonutil.LoggerForReplica(job, rtype) + // Get all pods for the type rt. + pods, err := r.FilterPodsForReplicaType(pods, rtype) + if err != nil { + return err + } + numReplicas := int(*spec.Replicas) + var masterRole bool + + core.InitializeReplicaStatuses(jobStatus, rtype) + + // GetPodSlices will return enough information here to make decision to add/remove/update resources. + // + // For example, let's assume we have pods with replica-index 0, 1, 2 + // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a pod with replica-index 3 will be created. + // + // If replica is 1, return a slice with size 3. [[0],[1],[2]], pod with replica-index 1 and 2 are out of range and will be deleted. + podSlices := r.GetPodSlices(pods, numReplicas, logger) + for index, podSlice := range podSlices { + if len(podSlice) > 1 { + logger.Warningf("We have too many pods for %s %d", rtype, index) + } else if len(podSlice) == 0 { + logger.Infof("Need to create new pod: %s-%d", rtype, index) + + // check if this replica is the master role + masterRole = r.IsMasterRole(replicas, rtype, index) + err = r.CreateNewPod(job, rtype, strconv.Itoa(index), spec, masterRole, replicas) + if err != nil { + return err + } + } else { + // Check the status of the current pod. + pod := podSlice[0] + + // check if the index is in the valid range, if not, we should kill the pod + if index < 0 || index >= numReplicas { + err = r.DeletePod(ctx, pod.Namespace, pod.Name) + if err != nil { + return err + } + } + + // Get the exit code of the container. + var exitCode int32 = 0xbeef // magic number + for _, status := range pod.Status.ContainerStatuses { + state := status.State + if status.Name == r.GetDefaultContainerName() && state.Terminated != nil { + exitCode = state.Terminated.ExitCode + logger.Infof("Pod: %v.%v exited with code %v", pod.Namespace, pod.Name, exitCode) + r.GetRecorder().Eventf(job, corev1.EventTypeNormal, "ExitedWithCode", "Pod: %v.%v exited with code %v", pod.Namespace, pod.Name, exitCode) + } + } + // Check if the pod is retryable. 
+ if spec.RestartPolicy == commonv1.RestartPolicyExitCode { + if pod.Status.Phase == corev1.PodFailed && trainutil.IsRetryableExitCode(exitCode) { + failedPodsCount.Inc() + logger.Infof("Need to restart the pod: %v.%v", pod.Namespace, pod.Name) + if err = r.DeletePod(ctx, pod.Namespace, pod.Name); err != nil { + return err + } + } + } + + core.UpdateJobReplicaStatuses(jobStatus, rtype, pod) + } + } + return nil + +} + +// CreateNewPod generate Pods for this job and submits creation request to APIServer +func (r *KubeflowPodReconciler) CreateNewPod(job client.Object, rt commonv1.ReplicaType, index string, + spec *commonv1.ReplicaSpec, masterRole bool, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error { + + logger := commonutil.LoggerForReplica(job, rt) + + podLabels := r.GenLabels(job.GetName()) + podLabels[commonv1.ReplicaTypeLabel] = string(rt) + podLabels[commonv1.ReplicaIndexLabel] = index + if masterRole { + podLabels[commonv1.JobRoleLabel] = "master" + } + + podTemplate := spec.Template.DeepCopy() + + podTemplate.Name = r.GenPodName(job.GetName(), rt, index) + podTemplate.Namespace = job.GetNamespace() + if podTemplate.Labels == nil { + podTemplate.Labels = make(map[string]string) + } + + for key, value := range podLabels { + podTemplate.Labels[key] = value + } + + if podTemplate.Spec.RestartPolicy != corev1.RestartPolicy("") { + errMsg := "Restart policy in pod template will be overwritten by restart policy in replica spec" + logger.Warning(errMsg) + r.GetRecorder().Event(job, corev1.EventTypeWarning, "SettedPodTemplateRestartPolicy", errMsg) + } + if spec.RestartPolicy == commonv1.RestartPolicyExitCode { + podTemplate.Spec.RestartPolicy = corev1.RestartPolicyNever + } else { + podTemplate.Spec.RestartPolicy = corev1.RestartPolicy(spec.RestartPolicy) + } + + if r.GangSchedulingEnabled() { + r.DecoratePodForGangScheduling(rt, podTemplate, job) + } + + r.DecoratePod(rt, podTemplate, job) + + pod := &corev1.Pod{ + ObjectMeta: podTemplate.ObjectMeta, + Spec: podTemplate.Spec, + } + + err := controllerutil.SetControllerReference(job, pod, r.GetScheme()) + if err != nil { + return err + } + + err = r.Create(context.Background(), pod) + if err != nil && errors.IsTimeout(err) { + return nil + } else if err != nil { + return err + } + createdPodsCount.Inc() + return nil +} + +// DeletePod delete a Pod specified by name and namespace +func (r *KubeflowPodReconciler) DeletePod(ctx context.Context, ns string, name string) error { + pod := &corev1.Pod{} + pod.Name = name + pod.Namespace = ns + err := r.Delete(ctx, pod) + if err == nil { + deletedPodsCount.Inc() + } + return err +} + +// DecoratePod decorates podTemplate before a Pod is submitted to the APIServer +func (r *KubeflowPodReconciler) DecoratePod(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) { + // Default implementation applies nothing to podTemplate + return +} diff --git a/pkg/reconciler.v1/common/pod_test.go b/pkg/reconciler.v1/common/pod_test.go new file mode 100644 index 00000000..f2afeec3 --- /dev/null +++ b/pkg/reconciler.v1/common/pod_test.go @@ -0,0 +1,143 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common_test + +import ( + "testing" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/reconciler.v1/common" + testjobv1 "github.com/kubeflow/common/test_job/apis/test_job/v1" + "github.com/kubeflow/common/test_job/reconciler.v1/test_job" + testutilv1 "github.com/kubeflow/common/test_job/test_util/v1" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestGenPodName(t *testing.T) { + type tc struct { + testJob *testjobv1.TestJob + testRType commonv1.ReplicaType + testIndex string + expectedName string + } + testCase := []tc{ + func() tc { + tj := testutilv1.NewTestJob(1) + tj.SetName("hello-world") + return tc{ + testJob: tj, + testRType: commonv1.ReplicaType(testjobv1.TestReplicaTypeWorker), + testIndex: "1", + expectedName: "hello-world-worker-1", + } + }(), + } + + actualReconciler := test_job.NewTestReconciler() + var testReconciler common.KubeflowReconcilerInterface = actualReconciler + + for _, c := range testCase { + na := testReconciler.GenPodName(c.testJob.GetName(), c.testRType, c.testIndex) + if na != c.expectedName { + t.Errorf("Expected %s, got %s", c.expectedName, na) + } + } +} + +func PodInSlice(pod *corev1.Pod, pods []*corev1.Pod) bool { + for _, p := range pods { + if p.GetNamespace() == pod.GetNamespace() && p.GetName() == pod.GetName() { + return true + } + } + return false +} + +func TestFilterPodsForReplicaType(t *testing.T) { + type tc struct { + testPods []*corev1.Pod + testRType commonv1.ReplicaType + expectedPods []*corev1.Pod + } + testCase := []tc{ + func() tc { + tj := testutilv1.NewTestJob(3) + tj.SetName("hello-world") + + pod0 := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod0", + Namespace: "default", + Labels: map[string]string{ + commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeMaster), + }, + }, + Spec: corev1.PodSpec{}, + Status: corev1.PodStatus{}, + } + + pod1 := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: "default", + Labels: map[string]string{ + commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeWorker), + }, + }, + Spec: corev1.PodSpec{}, + Status: corev1.PodStatus{}, + } + + pod2 := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod2", + Namespace: "default", + Labels: map[string]string{ + commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeWorker), + }, + }, + Spec: corev1.PodSpec{}, + Status: corev1.PodStatus{}, + } + + allPods := []*corev1.Pod{pod0, pod1, pod2} + filteredPods := []*corev1.Pod{pod1, pod2} + + return tc{ + testPods: allPods, + testRType: commonv1.ReplicaType(testjobv1.TestReplicaTypeWorker), + expectedPods: filteredPods, + } + }(), + } + + actualReconciler := test_job.NewTestReconciler() + var testReconciler common.KubeflowReconcilerInterface = actualReconciler + + for _, c := range testCase { + filtered, err := testReconciler.FilterPodsForReplicaType(c.testPods, c.testRType) + if err != nil { + t.Errorf("FilterPodsForReplicaType returns error %v", err) + } + for _, ep := range c.expectedPods { + if !PodInSlice(ep, filtered) { + t.Errorf("Cannot 
found expected pod %s", ep.GetName()) + } + } + + } +} diff --git a/pkg/reconciler.v1/common/reconciler.go b/pkg/reconciler.v1/common/reconciler.go new file mode 100644 index 00000000..fb1a9b0f --- /dev/null +++ b/pkg/reconciler.v1/common/reconciler.go @@ -0,0 +1,147 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "context" + + corev1 "k8s.io/api/core/v1" + "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/manager" +) + +// KubeflowReconciler reconciles a KubeflowJob object +type KubeflowReconciler struct { + JobInterface + PodInterface + ServiceInterface + GangSchedulingInterface + ReconcilerUtilInterface +} + +// BareKubeflowReconciler returns a pointer of KubeflowReconciler with minimal implementation +func BareKubeflowReconciler() *KubeflowReconciler { + return &KubeflowReconciler{} +} + +// DefaultKubeflowReconciler generates the default KubeflowReconciler with default sub-reconcilers fully setup +func DefaultKubeflowReconciler(mgr manager.Manager, gangEnable bool) *KubeflowReconciler { + kubeflowReconciler := BareKubeflowReconciler() + + // Generate Bare Components + jobInter := BareKubeflowJobReconciler(mgr.GetClient()) + podInter := BareKubeflowPodReconciler(mgr.GetClient()) + svcInter := BareKubeflowServiceReconciler(mgr.GetClient()) + gangInter := BareVolcanoReconciler(mgr.GetClient(), nil, gangEnable) + utilInter := BareUtilReconciler(mgr.GetEventRecorderFor(kubeflowReconciler.GetReconcilerName()), mgr.GetLogger(), mgr.GetScheme()) + + // Assign interfaces for jobInterface + jobInter.PodInterface = podInter + jobInter.ServiceInterface = svcInter + jobInter.GangSchedulingInterface = gangInter + jobInter.ReconcilerUtilInterface = utilInter + + // Assign interfaces for podInterface + podInter.JobInterface = jobInter + podInter.GangSchedulingInterface = gangInter + podInter.ReconcilerUtilInterface = utilInter + + // Assign interfaces for svcInterface + svcInter.PodInterface = podInter + svcInter.JobInterface = jobInter + svcInter.ReconcilerUtilInterface = utilInter + + // Assign interfaces for gangInterface + gangInter.ReconcilerUtilInterface = utilInter + + // Prepare KubeflowReconciler + kubeflowReconciler.JobInterface = jobInter + kubeflowReconciler.PodInterface = podInter + kubeflowReconciler.ServiceInterface = svcInter + kubeflowReconciler.GangSchedulingInterface = gangInter + kubeflowReconciler.ReconcilerUtilInterface = utilInter + + return kubeflowReconciler +} + +// OverrideForKubeflowReconcilerInterface resets JobInterface, PodInterface, ServiceInterface, GangSchedulingInterface, +// ReconcilerUtilInterface for KubeflowReconciler and its sub-reconcilers +func (r *KubeflowReconciler) OverrideForKubeflowReconcilerInterface( + ji JobInterface, + pi PodInterface, + si ServiceInterface, + gi GangSchedulingInterface, + ui ReconcilerUtilInterface) 
{ + r.JobInterface.OverrideForJobInterface(ui, pi, si, gi) + r.PodInterface.OverrideForPodInterface(ui, gi, ji) + r.ServiceInterface.OverrideForServiceInterface(ui, pi, ji) + r.GangSchedulingInterface.OverrideForGangSchedulingInterface(ui) +} + +// Reconcile is part of the main kubernetes reconciliation loop which aims to +// move the current state of the cluster closer to the desired state. +func (r *KubeflowReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + _ = log.FromContext(ctx) + + job, err := r.GetJob(ctx, req) + if err != nil { + return ctrl.Result{}, err + } + + logger := r.GetLogger(job) + + if job.GetDeletionTimestamp() != nil { + logger.Info(MsgReconcileCancelled, ReasonKey, ReasonJobDeleted) + return ctrl.Result{}, nil + } + + scheme.Scheme.Default(job) + + // Get rid of SatisfiedExpectation + replicasSpec, err := r.ExtractReplicasSpec(job) + if err != nil { + return ctrl.Result{}, err + } + + runPolicy, err := r.ExtractRunPolicy(job) + if err != nil { + return ctrl.Result{}, err + } + + status, err := r.ExtractJobStatus(job) + if err != nil { + return ctrl.Result{}, err + } + + err = r.ReconcileJob(ctx, job, replicasSpec, status, runPolicy) + if err != nil { + logger.Info("Reconcile PyTorch Job error %v", err) + return ctrl.Result{}, err + } + + return ctrl.Result{}, nil +} + +// SetupWithManager sets up the controller with the Manager. +func (r *KubeflowReconciler) SetupWithManager(mgr ctrl.Manager, obj client.Object) error { + return ctrl.NewControllerManagedBy(mgr). + For(obj). + Owns(&corev1.Pod{}). + Owns(&corev1.Service{}). + Complete(r) +} diff --git a/pkg/reconciler.v1/common/service.go b/pkg/reconciler.v1/common/service.go new file mode 100644 index 00000000..63e7791e --- /dev/null +++ b/pkg/reconciler.v1/common/service.go @@ -0,0 +1,221 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
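Reconcile and SetupWithManager above are all a downstream operator needs from controller-runtime's point of view. A hedged sketch of the wiring in a main package follows; myv1 and MyJob are hypothetical stand-ins for the operator's own API types, and error handling is trimmed to the minimum:

package main

import (
	"os"

	"github.com/kubeflow/common/pkg/reconciler.v1/common"
	"k8s.io/apimachinery/pkg/runtime"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/log/zap"

	myv1 "example.com/myjob/api/v1" // hypothetical CRD API package
)

func main() {
	ctrl.SetLogger(zap.New())

	// Register both the built-in types and the custom job type.
	scheme := runtime.NewScheme()
	utilruntime.Must(clientgoscheme.AddToScheme(scheme))
	utilruntime.Must(myv1.AddToScheme(scheme))

	mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{Scheme: scheme})
	if err != nil {
		os.Exit(1)
	}

	// Gang scheduling disabled here; pass true to enable the volcano-backed path.
	r := common.DefaultKubeflowReconciler(mgr, false)
	if err := r.SetupWithManager(mgr, &myv1.MyJob{}); err != nil {
		os.Exit(1)
	}

	if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
		os.Exit(1)
	}
}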
+ +package common + +import ( + "context" + "strconv" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/core" + commonutil "github.com/kubeflow/common/pkg/util" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + log "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" +) + +var ( + succeededServiceCreationCount = promauto.NewCounter(prometheus.CounterOpts{ + Name: "reconciler_succeeded_service_creation_total", + Help: "The total number of succeeded service creation", + }) + failedServiceCreationCount = promauto.NewCounter(prometheus.CounterOpts{ + Name: "reconciler_failed_service_creation_total", + Help: "The total number of failed service creation", + }) +) + +// KubeflowServiceReconciler defines a Service Reconciler for KubeflowJob +type KubeflowServiceReconciler struct { + client.Client + ReconcilerUtilInterface + PodInterface + JobInterface +} + +// BareKubeflowServiceReconciler returns a pointer of KubeflowServiceReconciler with minimal implementation +func BareKubeflowServiceReconciler(client client.Client) *KubeflowServiceReconciler { + return &KubeflowServiceReconciler{ + Client: client, + } +} + +// OverrideForServiceInterface resets ReconcilerUtilInterface, PodInterface, JobInterface for KubeflowServiceReconciler +func (r *KubeflowServiceReconciler) OverrideForServiceInterface(ui ReconcilerUtilInterface, pi PodInterface, ji JobInterface) { + if ui != nil { + r.ReconcilerUtilInterface = ui + } + if pi != nil { + r.PodInterface = pi + } + if ji != nil { + r.JobInterface = ji + } +} + +// GetPortsFromJob gets the ports of job container. Port could be nil, if distributed communication strategy doesn't need and no other ports that need to be exposed. +func (r *KubeflowServiceReconciler) GetPortsFromJob(spec *commonv1.ReplicaSpec) (map[string]int32, error) { + defaultContainerName := r.GetDefaultContainerName() + return core.GetPortsFromJob(spec, defaultContainerName) +} + +// GetServicesForJob returns all services associated with this job +func (r *KubeflowServiceReconciler) GetServicesForJob(ctx context.Context, job client.Object) ([]*corev1.Service, error) { + svcList := &corev1.ServiceList{} + err := r.List(ctx, svcList, client.MatchingLabels(r.GenLabels(job.GetName()))) + if err != nil { + return nil, err + } + + var svcs []*corev1.Service = nil + for _, svc := range svcList.Items { + svcs = append(svcs, &svc) + } + + return svcs, nil +} + +// FilterServicesForReplicaType returns service belong to a replicaType. 
+func (r *KubeflowServiceReconciler) FilterServicesForReplicaType(services []*corev1.Service, + replicaType commonv1.ReplicaType) ([]*corev1.Service, error) { + return core.FilterServicesForReplicaType(services, replicaType) +} + +// GetServiceSlices returns the serviceSlice based on all Services listed for this job +func (r *KubeflowServiceReconciler) GetServiceSlices(services []*corev1.Service, replicas int, logger *log.Entry) [][]*corev1.Service { + return core.GetServiceSlices(services, replicas, logger) +} + +// ReconcileServices reconciles the Services for this job +func (r *KubeflowServiceReconciler) ReconcileServices( + job client.Object, + services []*corev1.Service, + rtype commonv1.ReplicaType, + spec *commonv1.ReplicaSpec) error { + + replicas := int(*spec.Replicas) + // Get all services for the type rt. + services, err := r.FilterServicesForReplicaType(services, rtype) + if err != nil { + return err + } + + // GetServiceSlices will return enough information here to make decision to add/remove/update resources. + // + // For example, let's assume we have services with replica-index 0, 1, 2 + // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created. + // + // If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted. + serviceSlices := r.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype)) + + for index, serviceSlice := range serviceSlices { + if len(serviceSlice) > 1 { + commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index) + } else if len(serviceSlice) == 0 { + commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index) + err = r.CreateNewService(job, rtype, spec, strconv.Itoa(index)) + if err != nil { + return err + } + } else { + // Check the status of the current svc. + svc := serviceSlice[0] + + // check if the index is in the valid range, if not, we should kill the svc + if index < 0 || index >= replicas { + err = r.DeleteService(svc.Namespace, svc.Name, job) + if err != nil { + return err + } + } + } + } + return nil + +} + +// CreateNewService generates Service based the job, replica info. and index and submits it to APIServer +func (r *KubeflowServiceReconciler) CreateNewService(job client.Object, rtype commonv1.ReplicaType, + spec *commonv1.ReplicaSpec, index string) error { + + // Append ReplicaTypeLabel and ReplicaIndexLabel labels. + labels := r.GenLabels(job.GetName()) + labels[commonv1.ReplicaTypeLabel] = string(rtype) + labels[commonv1.ReplicaIndexLabel] = index + + ports, err := r.GetPortsFromJob(spec) + if err != nil { + return err + } + + service := &corev1.Service{ + Spec: corev1.ServiceSpec{ + ClusterIP: "None", + Selector: labels, + Ports: []corev1.ServicePort{}, + }, + } + + // Add service ports to headless service + for name, port := range ports { + svcPort := corev1.ServicePort{Name: name, Port: port} + service.Spec.Ports = append(service.Spec.Ports, svcPort) + } + + service.Name = core.GenGeneralName(job.GetName(), rtype, index) + service.Namespace = job.GetNamespace() + service.Labels = labels + // Create OwnerReference. 
+ err = controllerutil.SetControllerReference(job, service, r.GetScheme()) + if err != nil { + return err + } + + r.DecorateService(rtype, service, job) + + err = r.Create(context.Background(), service) + if err != nil && errors.IsTimeout(err) { + succeededServiceCreationCount.Inc() + return nil + } else if err != nil { + failedServiceCreationCount.Inc() + return err + } + succeededServiceCreationCount.Inc() + return nil +} + +// DeleteService deletes a Service specified by its name and namespace from APIServer +func (r *KubeflowServiceReconciler) DeleteService(ns string, name string, job client.Object) error { + svc := &corev1.Service{} + svc.Name = name + svc.Namespace = ns + err := r.Delete(context.Background(), svc) + if err == nil { + deletedPodsCount.Inc() + } + return err +} + +// DecorateService decorates the Service before it's submitted to APIServer +func (r *KubeflowServiceReconciler) DecorateService(rtype commonv1.ReplicaType, svc *corev1.Service, job client.Object) { + // Default implementation applies nothing to podTemplate + return +} diff --git a/pkg/reconciler.v1/common/service_test.go b/pkg/reconciler.v1/common/service_test.go new file mode 100644 index 00000000..b6d743c9 --- /dev/null +++ b/pkg/reconciler.v1/common/service_test.go @@ -0,0 +1,103 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
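DecorateService (like DecoratePod earlier) is deliberately a no-op so that an embedding reconciler can adjust the object right before creation. A hedged sketch, continuing the hypothetical MyJobReconciler from the earlier example (imports as in service.go):

// DecorateService overrides the CAN hook to tag every generated headless
// Service with its replica type; the annotation key is illustrative only.
func (r *MyJobReconciler) DecorateService(rtype commonv1.ReplicaType, svc *corev1.Service, job client.Object) {
	if svc.Annotations == nil {
		svc.Annotations = map[string]string{}
	}
	svc.Annotations["myjob.example.com/replica-type"] = string(rtype)
}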
+ +package common_test + +import ( + "reflect" + "testing" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/reconciler.v1/common" + testjobv1 "github.com/kubeflow/common/test_job/apis/test_job/v1" + "github.com/kubeflow/common/test_job/reconciler.v1/test_job" + test_utilv1 "github.com/kubeflow/common/test_job/test_util/v1" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestCreateNewService(t *testing.T) { + type tc struct { + testJob *testjobv1.TestJob + testRType commonv1.ReplicaType + testSpec *commonv1.ReplicaSpec + testIndex string + expectedService *corev1.Service + } + testCase := []tc{ + func() tc { + tj := test_utilv1.NewTestJob(3) + jobName := "testjob1" + tj.SetName(jobName) + idx := "0" + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: jobName + "-worker-" + idx, + Namespace: corev1.NamespaceDefault, + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + corev1.ServicePort{ + Name: testjobv1.DefaultPortName, + Port: testjobv1.DefaultPort, + }, + }, + ClusterIP: corev1.ClusterIPNone, + Selector: map[string]string{ + commonv1.GroupNameLabelDeprecated: testjobv1.GroupName, + commonv1.OperatorNameLabel: "Test Reconciler", + commonv1.JobNameLabelDeprecated: jobName, + commonv1.JobNameLabel: jobName, + commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeWorker), + commonv1.ReplicaIndexLabel: idx, + }, + }, + } + return tc{ + testJob: tj, + testRType: commonv1.ReplicaType(testjobv1.TestReplicaTypeWorker), + testSpec: tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker], + testIndex: idx, + expectedService: svc, + } + }(), + } + actualReconciler := test_job.NewTestReconciler() + var testReconciler common.KubeflowReconcilerInterface = actualReconciler + + for _, c := range testCase { + err := testReconciler.CreateNewService(c.testJob, c.testRType, c.testSpec, c.testIndex) + if err != nil { + t.Errorf("Got error when CreateNewService: %v", err) + continue + } + + found := false + for _, obj := range actualReconciler.DC.Cache { + if obj.GetName() == c.expectedService.GetName() && obj.GetNamespace() == c.expectedService.GetNamespace() { + found = true + svcCreated := obj.(*corev1.Service) + svcExpected := c.expectedService + if !reflect.DeepEqual(svcExpected.Spec, svcCreated.Spec) { + t.Errorf("Spec mismatch for service %s/%s", svcExpected.GetNamespace(), svcExpected.GetName()) + } + } + } + + if !found { + t.Errorf("Cannot find Service %s/%s created", c.expectedService.GetNamespace(), c.expectedService.GetName()) + } + } +} diff --git a/pkg/reconciler.v1/common/utils.go b/pkg/reconciler.v1/common/utils.go new file mode 100644 index 00000000..5ec38b63 --- /dev/null +++ b/pkg/reconciler.v1/common/utils.go @@ -0,0 +1,66 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package common + +import ( + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ReconcilerName = "Kubeflow Reconciler" + +// GetReconcilerName returns the name of this reconciler, which is "Kubeflow Reconciler" +func (r *ReconcilerUtil) GetReconcilerName() string { + return ReconcilerName +} + +// ReconcilerUtil defines a reconciler with utility features +type ReconcilerUtil struct { + Recorder record.EventRecorder + Log logr.Logger + Scheme *runtime.Scheme +} + +// BareUtilReconciler returns a pointer of ReconcilerUtil with default implementation +func BareUtilReconciler( + recorder record.EventRecorder, + log logr.Logger, + scheme *runtime.Scheme) *ReconcilerUtil { + return &ReconcilerUtil{ + Recorder: recorder, + Log: log, + Scheme: scheme, + } +} + +// GetRecorder returns a record.EventRecorder +func (r *ReconcilerUtil) GetRecorder() record.EventRecorder { + return r.Recorder +} + +// GetLogger returns a logr.Logger +func (r *ReconcilerUtil) GetLogger(job client.Object) logr.Logger { + return r.Log.WithValues( + job.GetObjectKind().GroupVersionKind().Kind, + types.NamespacedName{Name: job.GetName(), Namespace: job.GetNamespace()}.String()) +} + +// GetScheme returns the pointer of runtime.Schemes that is used in this reconciler +func (r *ReconcilerUtil) GetScheme() *runtime.Scheme { + return r.Scheme +} diff --git a/pkg/reconciler.v1/common/utils_test.go b/pkg/reconciler.v1/common/utils_test.go new file mode 100644 index 00000000..e3c31e4e --- /dev/null +++ b/pkg/reconciler.v1/common/utils_test.go @@ -0,0 +1,65 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
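BareUtilReconciler only needs an event recorder, a logr.Logger and a scheme, so it can also be assembled without a manager, for example in unit tests. A sketch under that assumption (the fake recorder and development-mode logger are standard client-go / controller-runtime helpers):

package common_test

import (
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/client-go/tools/record"
	"sigs.k8s.io/controller-runtime/pkg/log/zap"

	"github.com/kubeflow/common/pkg/reconciler.v1/common"
)

// newTestUtil builds the utility sub-reconciler with test-friendly dependencies:
// a buffered fake recorder, a development-mode logger and an empty scheme.
func newTestUtil() *common.ReconcilerUtil {
	return common.BareUtilReconciler(
		record.NewFakeRecorder(16),
		zap.New(zap.UseDevMode(true)),
		runtime.NewScheme(),
	)
}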
+ +package common_test + +import ( + "testing" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + testjobv1 "github.com/kubeflow/common/test_job/apis/test_job/v1" + + "github.com/kubeflow/common/pkg/reconciler.v1/common" + "github.com/kubeflow/common/test_job/reconciler.v1/test_job" +) + +func TestGenLabels(t *testing.T) { + type tc struct { + testJobName string + expectedLabel map[string]string + } + testCase := []tc{ + func() tc { + return tc{ + testJobName: "test/job1", + expectedLabel: map[string]string{ + commonv1.GroupNameLabelDeprecated: testjobv1.GroupName, + commonv1.JobNameLabel: "test-job1", + commonv1.JobNameLabelDeprecated: "test-job1", + commonv1.OperatorNameLabel: "Test Reconciler", + }, + } + }(), + } + + actualReconciler := test_job.NewTestReconciler() + var testReconciler common.KubeflowReconcilerInterface = actualReconciler + + for _, c := range testCase { + labels := testReconciler.GenLabels(c.testJobName) + if len(labels) != len(c.expectedLabel) { + t.Errorf("Expected to get %d labels, got %d labels", len(c.expectedLabel), len(labels)) + continue + } + for ek, ev := range c.expectedLabel { + if v, ok := labels[ek]; !ok { + t.Errorf("Cannot found expected key %s", ek) + } else { + if ev != v { + t.Errorf("Expected to get %s for %s, got %s", ev, ek, v) + } + } + } + } +} diff --git a/pkg/util/counter.go b/pkg/util/counter.go new file mode 100644 index 00000000..0fb5fa26 --- /dev/null +++ b/pkg/util/counter.go @@ -0,0 +1,71 @@ +package util + +import ( + "fmt" + "sync" +) + +type Counter struct { + lock sync.Mutex + data map[string]int +} + +func NewCounter() *Counter { + return &Counter{ + lock: sync.Mutex{}, + data: map[string]int{}, + } +} + +func (c *Counter) Inc(key string) { + c.lock.Lock() + defer c.lock.Unlock() + + v, ok := c.data[key] + if ok { + c.data[key] = v + 1 + return + } + c.data[key] = 0 +} + +func (c *Counter) DeleteKey(key string) { + c.lock.Lock() + defer c.lock.Lock() + + delete(c.data, key) +} + +func (c *Counter) Counts(key string) (int, error) { + c.lock.Lock() + defer c.lock.Unlock() + + v, ok := c.data[key] + if !ok { + return 0, fmt.Errorf("cannot get key %s", key) + } + var err error = nil + if v < 0 { + err = fmt.Errorf("count %s:%d is negative", key, v) + } + return v, err +} + +func (c *Counter) Dec(key string) error { + c.lock.Lock() + defer c.lock.Unlock() + + v, ok := c.data[key] + if ok { + if v > 1 { + c.data[key] = v - 1 + return nil + } + if v == 1 { + c.DeleteKey(key) + return nil + } + return fmt.Errorf("cannot minus one: key %s has value %d", key, v) + } + return fmt.Errorf("cannot find key %s", key) +} diff --git a/test_job/reconciler.v1/test_job/dummy_client.go b/test_job/reconciler.v1/test_job/dummy_client.go new file mode 100644 index 00000000..b1be597a --- /dev/null +++ b/test_job/reconciler.v1/test_job/dummy_client.go @@ -0,0 +1,60 @@ +package test_job + +import ( + "context" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + + "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type DummyClient struct { + scheme *runtime.Scheme + mapper meta.RESTMapper + client.Reader + client.Writer + client.StatusClient + Cache []client.Object +} + +func (c *DummyClient) Scheme() *runtime.Scheme { + return c.scheme +} + +func (c *DummyClient) RESTMapper() meta.RESTMapper { + return c.mapper +} + +func (c *DummyClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + c.Cache = 
append(c.Cache, obj) + return nil +} + +func (c *DummyClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + for idx, o := range c.Cache { + if o.GetName() == obj.GetName() && o.GetNamespace() == obj.GetNamespace() && o.GetObjectKind() == obj.GetObjectKind() { + c.Cache = append(c.Cache[:idx], c.Cache[idx+1:]...) + return nil + } + } + return errors.NewNotFound(schema.GroupResource{ + Group: obj.GetObjectKind().GroupVersionKind().Group, + Resource: obj.GetSelfLink(), + }, obj.GetName()) +} + +func (c *DummyClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + for idx, o := range c.Cache { + if o.GetName() == obj.GetName() && o.GetNamespace() == obj.GetNamespace() && o.GetObjectKind() == obj.GetObjectKind() { + c.Cache[idx] = obj + return nil + } + } + return errors.NewNotFound(schema.GroupResource{ + Group: obj.GetObjectKind().GroupVersionKind().Group, + Resource: obj.GetSelfLink(), + }, obj.GetName()) +} diff --git a/test_job/reconciler.v1/test_job/test_job_reconciler.go b/test_job/reconciler.v1/test_job/test_job_reconciler.go new file mode 100644 index 00000000..5413af42 --- /dev/null +++ b/test_job/reconciler.v1/test_job/test_job_reconciler.go @@ -0,0 +1,131 @@ +package test_job + +import ( + "context" + + "github.com/go-logr/logr" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + common_reconciler "github.com/kubeflow/common/pkg/reconciler.v1/common" + v1 "github.com/kubeflow/common/test_job/apis/test_job/v1" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +var _ common_reconciler.KubeflowReconcilerInterface = &TestReconciler{} + +type TestReconciler struct { + common_reconciler.KubeflowReconciler + DC *DummyClient + Job *v1.TestJob + Pods []*corev1.Pod + Services []*corev1.Service + PodGroup client.Object +} + +func NewTestReconciler() *TestReconciler { + scheme := runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(v1.AddToScheme(scheme)) + + kubeflowReconciler := common_reconciler.BareKubeflowReconciler() + + dummy_client := &DummyClient{} + + // Generate Bare Components + jobInter := common_reconciler.BareKubeflowJobReconciler(dummy_client) + podInter := common_reconciler.BareKubeflowPodReconciler(dummy_client) + svcInter := common_reconciler.BareKubeflowServiceReconciler(dummy_client) + gangInter := common_reconciler.BareVolcanoReconciler(dummy_client, nil, true) + utilInter := common_reconciler.BareUtilReconciler(nil, logr.FromContext(context.Background()), scheme) + + // Assign interfaces for jobInterface + jobInter.PodInterface = podInter + jobInter.ServiceInterface = svcInter + jobInter.GangSchedulingInterface = gangInter + jobInter.ReconcilerUtilInterface = utilInter + + // Assign interfaces for podInterface + podInter.JobInterface = jobInter + podInter.GangSchedulingInterface = gangInter + podInter.ReconcilerUtilInterface = utilInter + + // Assign interfaces for svcInterface + svcInter.PodInterface = podInter + svcInter.JobInterface = jobInter + svcInter.ReconcilerUtilInterface = utilInter + + // Assign interfaces for gangInterface + gangInter.ReconcilerUtilInterface = utilInter + + // Prepare KubeflowReconciler + kubeflowReconciler.JobInterface = jobInter + kubeflowReconciler.PodInterface = podInter + 
kubeflowReconciler.ServiceInterface = svcInter + kubeflowReconciler.GangSchedulingInterface = gangInter + kubeflowReconciler.ReconcilerUtilInterface = utilInter + + testReconciler := &TestReconciler{ + KubeflowReconciler: *kubeflowReconciler, + DC: dummy_client, + } + testReconciler.OverrideForKubeflowReconcilerInterface(testReconciler, testReconciler, testReconciler, testReconciler, testReconciler) + + return testReconciler +} + +func (r *TestReconciler) GetReconcilerName() string { + return "Test Reconciler" +} + +func (r *TestReconciler) GetJob(ctx context.Context, req ctrl.Request) (client.Object, error) { + return r.Job, nil +} + +func (r *TestReconciler) GetDefaultContainerName() string { + return v1.DefaultContainerName +} + +func (r *TestReconciler) GetPodGroupForJob(ctx context.Context, job client.Object) (client.Object, error) { + return r.PodGroup, nil +} + +func (r *TestReconciler) GetPodsForJob(ctx context.Context, job client.Object) ([]*corev1.Pod, error) { + return r.Pods, nil +} + +func (r *TestReconciler) GetServicesForJob(ctx context.Context, job client.Object) ([]*corev1.Service, error) { + return r.Services, nil +} + +func (r *TestReconciler) ExtractReplicasSpec(job client.Object) (map[commonv1.ReplicaType]*commonv1.ReplicaSpec, error) { + tj := job.(*v1.TestJob) + + rs := map[commonv1.ReplicaType]*commonv1.ReplicaSpec{} + for k, v := range tj.Spec.TestReplicaSpecs { + rs[commonv1.ReplicaType(k)] = v + } + + return rs, nil +} + +func (r *TestReconciler) ExtractRunPolicy(job client.Object) (*commonv1.RunPolicy, error) { + tj := job.(*v1.TestJob) + + return tj.Spec.RunPolicy, nil +} + +func (r *TestReconciler) ExtractJobStatus(job client.Object) (*commonv1.JobStatus, error) { + tj := job.(*v1.TestJob) + + return &tj.Status, nil +} + +func (r *TestReconciler) IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, rtype commonv1.ReplicaType, index int) bool { + return string(rtype) == string(v1.TestReplicaTypeMaster) +}
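The TestReconciler above supplies only the MUST methods plus fake data sources; a production reconciler would normally also override the status-related CAN hooks whose defaults log WarnDefaultImplementationTemplate. A hedged sketch, again on the hypothetical MyJobReconciler introduced earlier:

// IsFlagReplicaTypeForJobStatus: let only the "Master" replica type drive the
// job-level Running/Succeeded conditions (the replica type name is illustrative).
func (r *MyJobReconciler) IsFlagReplicaTypeForJobStatus(rtype commonv1.ReplicaType) bool {
	return rtype == commonv1.ReplicaType("Master")
}

// UpdateJobStatus: delegate to the embedded default implementation, then add
// any CRD-specific bookkeeping (placeholder comment only).
func (r *MyJobReconciler) UpdateJobStatus(
	job client.Object,
	replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec,
	jobStatus *commonv1.JobStatus) error {
	if err := r.KubeflowReconciler.UpdateJobStatus(job, replicas, jobStatus); err != nil {
		return err
	}
	// e.g. mirror jobStatus.Conditions into a MyJob-specific status field here.
	return nil
}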