diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index 6ed4694a93..13cbee9889 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -504,6 +504,15 @@ func (r *PyTorchJobReconciler) SetClusterSpec(job interface{}, podTemplate *core return nil } +func (r *PyTorchJobReconciler) IsMasterRole(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec, + rtype kubeflowv1.ReplicaType, index int) bool { + if _, ok := replicas[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok { + return string(rtype) == strings.ToLower(string(kubeflowv1.PyTorchJobReplicaTypeMaster)) + } + // else check if it is worker with index 0 + return string(rtype) == strings.ToLower(string(kubeflowv1.PyTorchJobReplicaTypeWorker)) && index == 0 +} + func (r *PyTorchJobReconciler) GetDefaultContainerName() string { return kubeflowv1.PyTorchJobDefaultContainerName } @@ -512,11 +521,6 @@ func (r *PyTorchJobReconciler) GetDefaultContainerPortName() string { return kubeflowv1.PyTorchJobDefaultPortName } -func (r *PyTorchJobReconciler) IsMasterRole(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec, - rtype kubeflowv1.ReplicaType, index int) bool { - return string(rtype) == string(kubeflowv1.PyTorchJobReplicaTypeMaster) -} - // onOwnerCreateFunc modify creation condition. func (r *PyTorchJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool { return func(e event.CreateEvent) bool {