Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge 2065a96 into 76a80ec
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored May 9, 2023
2 parents 76a80ec + 2065a96 commit dce4e46
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
26 changes: 13 additions & 13 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pytorch

import (
"context"
"fmt"
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
Expand Down Expand Up @@ -69,9 +68,6 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName)

workers := pytorchTaskExtraArgs.GetWorkers()
if workers == 0 {
return nil, fmt.Errorf("number of worker should be more then 0")
}

var jobSpec kubeflowv1.PyTorchJobSpec

Expand Down Expand Up @@ -115,23 +111,27 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.PyTorchJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
},
}

if workers > 0 {
jobSpec.PyTorchReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker] = &commonOp.ReplicaSpec{
Replicas: &workers,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
}
}
}
job := &kubeflowv1.PyTorchJob{
TypeMeta: metav1.TypeMeta{
Kind: kubeflowv1.PytorchJobKind,
APIVersion: kubeflowv1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
Spec: jobSpec,
ObjectMeta: *objectMeta,
}

return job, nil
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func TestReplicaCounts(t *testing.T) {
contains []commonOp.ReplicaType
notContains []commonOp.ReplicaType
}{
{"NoWorkers", 0, true, nil, nil},
{"NoWorkers", 0, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster}, nil},
{"Works", 1, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster, kubeflowv1.PyTorchJobReplicaTypeWorker}, []commonOp.ReplicaType{}},
} {
t.Run(test.name, func(t *testing.T) {
Expand Down

0 comments on commit dce4e46

Please sign in to comment.