Skip to content

Commit

Permalink
feat(elastic): Implement main feature
Browse files Browse the repository at this point in the history
Signed-off-by: Ce Gao <ce.gao@outlook.com>
  • Loading branch information
gaocegege committed Nov 4, 2021
1 parent 51e02bc commit cdfd1be
Show file tree
Hide file tree
Showing 13 changed files with 690 additions and 61 deletions.
1 change: 0 additions & 1 deletion pkg/apis/mxnet/v1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 27 additions & 6 deletions pkg/apis/pytorch/v1/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,32 @@ func setDefaultPort(spec *v1.PodSpec) {
}
}

func setElasticPolicy(job *PyTorchJob) {
if job.Spec.ElasticPolicy != nil {
if job.Spec.ElasticPolicy.MaxReplicas != nil &&
job.Spec.ElasticPolicy.MinReplicas != nil {
return
} else if job.Spec.ElasticPolicy.MaxReplicas != nil {
// Set MinRepliacs to elasticPolicy.MaxReplicas.
job.Spec.ElasticPolicy.MinReplicas = job.Spec.ElasticPolicy.MaxReplicas
} else if job.Spec.ElasticPolicy.MinReplicas != nil {
job.Spec.ElasticPolicy.MaxReplicas = job.Spec.ElasticPolicy.MinReplicas
} else {
workerReplicas := job.Spec.PyTorchReplicaSpecs[PyTorchReplicaTypeWorker].Replicas
// Set Min and Max to worker.spec.Replicas.
job.Spec.ElasticPolicy.MaxReplicas = workerReplicas
job.Spec.ElasticPolicy.MinReplicas = workerReplicas
}
}
}

func setDefaultReplicas(spec *common.ReplicaSpec) {
if spec.Replicas == nil {
spec.Replicas = Int32(1)
}
}

func setDefaultRestartPolicy(spec *common.ReplicaSpec) {
if spec.RestartPolicy == "" {
spec.RestartPolicy = DefaultRestartPolicy
}
Expand Down Expand Up @@ -95,12 +117,11 @@ func SetDefaults_PyTorchJob(job *PyTorchJob) {
// Update the key of PyTorchReplicaSpecs to camel case.
setTypeNamesToCamelCase(job)

for rType, spec := range job.Spec.PyTorchReplicaSpecs {
// Set default replicas to 1.
for _, spec := range job.Spec.PyTorchReplicaSpecs {
setDefaultReplicas(spec)
if rType == PyTorchReplicaTypeMaster {
// Set default port to pytorch container of Master.
setDefaultPort(&spec.Template.Spec)
}
setDefaultRestartPolicy(spec)
setDefaultPort(&spec.Template.Spec)
}
// Set default elastic policy.
setElasticPolicy(job)
}
160 changes: 160 additions & 0 deletions pkg/apis/pytorch/v1/defaults_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package v1

import (
"testing"

"github.com/onsi/ginkgo"
"github.com/onsi/gomega"

v1 "github.com/kubeflow/common/pkg/apis/common/v1"
)

func TestSetElasticPolicy(t *testing.T) {
gomega.RegisterFailHandler(ginkgo.Fail)

type args struct {
job *PyTorchJob
}
type result struct {
expectedMinReplicas *int32
expectedMaxReplicas *int32
}
tests := []struct {
name string
args args
result result
}{
{
name: "minReplicas and maxReplicas to null",
args: args{
job: &PyTorchJob{
Spec: PyTorchJobSpec{
ElasticPolicy: &ElasticPolicy{},
PyTorchReplicaSpecs: map[v1.ReplicaType]*v1.ReplicaSpec{
PyTorchReplicaTypeWorker: {
Replicas: int32Ptr(1),
},
},
},
},
},
result: result{
expectedMinReplicas: int32Ptr(1),
expectedMaxReplicas: int32Ptr(1),
},
},
{
name: "minReplicas and maxReplicas to 1",
args: args{
job: &PyTorchJob{
Spec: PyTorchJobSpec{
ElasticPolicy: &ElasticPolicy{
MaxReplicas: int32Ptr(1),
MinReplicas: int32Ptr(1),
},
PyTorchReplicaSpecs: map[v1.ReplicaType]*v1.ReplicaSpec{
PyTorchReplicaTypeWorker: {
Replicas: int32Ptr(1),
},
},
},
},
},
result: result{
expectedMinReplicas: int32Ptr(1),
expectedMaxReplicas: int32Ptr(1),
},
},
{
name: "minReplicas and maxReplicas to 1",
args: args{
job: &PyTorchJob{
Spec: PyTorchJobSpec{
ElasticPolicy: &ElasticPolicy{
MaxReplicas: int32Ptr(1),
MinReplicas: int32Ptr(1),
},
PyTorchReplicaSpecs: map[v1.ReplicaType]*v1.ReplicaSpec{
PyTorchReplicaTypeWorker: {
Replicas: int32Ptr(1),
},
},
},
},
},
result: result{
expectedMinReplicas: int32Ptr(1),
expectedMaxReplicas: int32Ptr(1),
},
},
{
name: "minReplicas to null, maxRepliacs to 1",
args: args{
job: &PyTorchJob{
Spec: PyTorchJobSpec{
ElasticPolicy: &ElasticPolicy{
MaxReplicas: int32Ptr(1),
MinReplicas: nil,
},
PyTorchReplicaSpecs: map[v1.ReplicaType]*v1.ReplicaSpec{
PyTorchReplicaTypeWorker: {
Replicas: int32Ptr(1),
},
},
},
},
},
result: result{
expectedMinReplicas: int32Ptr(1),
expectedMaxReplicas: int32Ptr(1),
},
},
{
name: "maxRepliacs to null, minReplicas to 1",
args: args{
job: &PyTorchJob{
Spec: PyTorchJobSpec{
ElasticPolicy: &ElasticPolicy{
MaxReplicas: nil,
MinReplicas: int32Ptr(1),
},
PyTorchReplicaSpecs: map[v1.ReplicaType]*v1.ReplicaSpec{
PyTorchReplicaTypeWorker: {
Replicas: int32Ptr(1),
},
},
},
},
},
result: result{
expectedMinReplicas: int32Ptr(1),
expectedMaxReplicas: int32Ptr(1),
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
setElasticPolicy(test.args.job)
if test.result.expectedMinReplicas != nil {
gomega.Expect(test.args.job.Spec.ElasticPolicy.MinReplicas).
To(gomega.Equal(test.result.expectedMinReplicas))
} else {
gomega.Expect(test.args.job.Spec.ElasticPolicy.MinReplicas).
To(gomega.BeNil())
}

if test.result.expectedMaxReplicas != nil {
gomega.Expect(test.args.job.Spec.ElasticPolicy.MaxReplicas).
To(gomega.Equal(test.result.expectedMaxReplicas))
} else {
gomega.Expect(test.args.job.Spec.ElasticPolicy.MaxReplicas).
To(gomega.BeNil())
}
})
}
}

func int32Ptr(n int) *int32 {
val := int32(n)
return &val
}
118 changes: 117 additions & 1 deletion pkg/apis/pytorch/v1/openapi_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit cdfd1be

Please sign in to comment.