Skip to content

Commit

Permalink
Adapt tfjob to commonized kubeflowjob
Browse files Browse the repository at this point in the history
Signed-off-by: Yuki Iwai <yuki.iwai.tz@gmail.com>
  • Loading branch information
tenzen-y committed Aug 18, 2023
1 parent 2bdec4a commit b96cffe
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 44 deletions.
51 changes: 7 additions & 44 deletions pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,51 +85,14 @@ func (j *JobControl) JobStatus() kftraining.JobStatus {
return j.Status
}

func (j *JobControl) OrderedReplicaTypes(replicaSpecs map[kftraining.ReplicaType]*kftraining.ReplicaSpec) []kftraining.ReplicaType {
result := make([]kftraining.ReplicaType, 0, 5)
if _, ok := replicaSpecs[kftraining.TFJobReplicaTypeChief]; ok {
result = append(result, kftraining.TFJobReplicaTypeChief)
func (j *JobControl) OrderedReplicaTypes() []kftraining.ReplicaType {
return []kftraining.ReplicaType{
kftraining.TFJobReplicaTypeChief,
kftraining.TFJobReplicaTypeMaster,
kftraining.TFJobReplicaTypePS,
kftraining.TFJobReplicaTypeWorker,
kftraining.TFJobReplicaTypeEval,
}
if _, ok := replicaSpecs[kftraining.TFJobReplicaTypeMaster]; ok {
result = append(result, kftraining.TFJobReplicaTypeMaster)
}
if _, ok := replicaSpecs[kftraining.TFJobReplicaTypePS]; ok {
result = append(result, kftraining.TFJobReplicaTypePS)
}
if _, ok := replicaSpecs[kftraining.TFJobReplicaTypeWorker]; ok {
result = append(result, kftraining.TFJobReplicaTypeWorker)
}
if _, ok := replicaSpecs[kftraining.TFJobReplicaTypeEval]; ok {
result = append(result, kftraining.TFJobReplicaTypeEval)
}
return result
}

// PriorityClass calculates the priorityClass name needed for workload according to the following priorities:
// 1. .spec.runPolicy.schedulingPolicy.priorityClass
// 2. .spec.replicaSpecs[Chief].template.spec.priorityClassName
// 3. .spec.replicaSpecs[Master].template.spec.priorityClassName
// 4. .spec.replicaSpecs[ParameterServer].template.spec.priorityClassName
// 5. .spec.replicaSpecs[Worker].template.spec.priorityClassName
// 6. .spec.replicaSpecs[Evaluator].template.spec.priorityClassName
//
// This function is inspired by an analogous one in mpi-controller:
// https://github.com/kubeflow/mpi-operator/blob/5946ef4157599a474ab82ff80e780d5c2546c9ee/pkg/controller/podgroup.go#L69-L72
func (j *JobControl) PriorityClass() string {
if j.Spec.RunPolicy.SchedulingPolicy != nil && len(j.Spec.RunPolicy.SchedulingPolicy.PriorityClass) != 0 {
return j.Spec.RunPolicy.SchedulingPolicy.PriorityClass
} else if m := j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypeChief]; m != nil && len(m.Template.Spec.PriorityClassName) != 0 {
return m.Template.Spec.PriorityClassName
} else if m = j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypeMaster]; m != nil && len(m.Template.Spec.PriorityClassName) != 0 {
return m.Template.Spec.PriorityClassName
} else if m = j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypePS]; m != nil && len(m.Template.Spec.PriorityClassName) != 0 {
return m.Template.Spec.PriorityClassName
} else if m = j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypeWorker]; m != nil && len(m.Template.Spec.PriorityClassName) != 0 {
return m.Template.Spec.PriorityClassName
} else if m = j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypeEval]; m != nil && len(m.Template.Spec.PriorityClassName) != 0 {
return m.Template.Spec.PriorityClassName
}
return ""
}

func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error {
Expand Down
66 changes: 66 additions & 0 deletions pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package tfjob
import (
"testing"

"github.com/google/go-cmp/cmp"
kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
v1 "k8s.io/api/core/v1"
)
Expand Down Expand Up @@ -97,6 +98,20 @@ func TestCalcPriorityClassName(t *testing.T) {
},
},
},
kftraining.TFJobReplicaTypeWorker: {
Template: v1.PodTemplateSpec{
Spec: v1.PodSpec{
PriorityClassName: "worker-priority",
},
},
},
kftraining.TFJobReplicaTypeEval: {
Template: v1.PodTemplateSpec{
Spec: v1.PodSpec{
PriorityClassName: "eval-priority",
},
},
},
},
},
},
Expand Down Expand Up @@ -167,3 +182,54 @@ func TestCalcPriorityClassName(t *testing.T) {
})
}
}

func TestOrderedReplicaType(t *testing.T) {
testcases := map[string]struct {
job kftraining.TFJob
wantReplicaTypes []kftraining.ReplicaType
}{
"job has no replicas": {
job: kftraining.TFJob{},
wantReplicaTypes: []kftraining.ReplicaType{},
},
"job has all replicas": {
job: kftraining.TFJob{
Spec: kftraining.TFJobSpec{
TFReplicaSpecs: map[kftraining.ReplicaType]*kftraining.ReplicaSpec{
kftraining.TFJobReplicaTypePS: {},
kftraining.TFJobReplicaTypeEval: {},
kftraining.TFJobReplicaTypeWorker: {},
kftraining.TFJobReplicaTypeChief: {},
kftraining.TFJobReplicaTypeMaster: {},
},
},
},
wantReplicaTypes: []kftraining.ReplicaType{
kftraining.TFJobReplicaTypeChief,
kftraining.TFJobReplicaTypeMaster,
kftraining.TFJobReplicaTypePS,
kftraining.TFJobReplicaTypeWorker,
kftraining.TFJobReplicaTypeEval,
},
},
"job has only worker replicas": {
job: kftraining.TFJob{
Spec: kftraining.TFJobSpec{
TFReplicaSpecs: map[kftraining.ReplicaType]*kftraining.ReplicaSpec{
kftraining.TFJobReplicaTypeWorker: {},
},
},
},
wantReplicaTypes: []kftraining.ReplicaType{kftraining.PyTorchJobReplicaTypeWorker},
},
}
for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
tfJob := fromObject(&tc.job)
gotReplicaTypes := tfJob.OrderedReplicaTypes()
if diff := cmp.Diff(tc.wantReplicaTypes, gotReplicaTypes); len(diff) != 0 {
t.Errorf("Unexpected response (want: %v, got: %v)", tc.wantReplicaTypes, gotReplicaTypes)
}
})
}
}

0 comments on commit b96cffe

Please sign in to comment.