Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
add mpi job
Browse files Browse the repository at this point in the history
Signed-off-by: Yubo Wang <yubwang@linkedin.com>
  • Loading branch information
Yubo Wang committed Apr 27, 2023
1 parent 099bba0 commit 9626895
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 47 deletions.
5 changes: 4 additions & 1 deletion go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res
return restartPolicyMap[flyteRestartPolicy]
}

func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources) (*v1.PodSpec, error) {
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) (*v1.PodSpec, error) {
for idx, c := range podSpec.Containers {
if c.Name == containerName {
if image != "" {
Expand All @@ -240,6 +240,9 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri
}
podSpec.Containers[idx].Resources = *resources
}
if args != nil && len(args) != 0 {
podSpec.Containers[idx].Args = args
}
}
}
return podSpec, nil
Expand Down
147 changes: 103 additions & 44 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"

flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand Down Expand Up @@ -48,77 +50,134 @@ func (mpiOperatorResourceHandler) BuildIdentityResource(ctx context.Context, tas
// Defines a func to create the full resource object that will be posted to k8s.
func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
taskTemplateConfig := taskTemplate.GetConfig()

if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error())
} else if taskTemplate == nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification")
}

mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

workers := mpiTaskExtraArgs.GetNumWorkers()
launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas()
slots := mpiTaskExtraArgs.GetSlots()

podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}
common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.MPIJobDefaultContainerName)

// workersPodSpec is deepCopy of podSpec submitted by flyte
workersPodSpec := podSpec.DeepCopy()
var launcherReplica = common.ReplicaEntry{
ReplicaNum: int32(1),
PodSpec: podSpec.DeepCopy(),
RestartPolicy: commonOp.RestartPolicyNever,
}
var workerReplica = common.ReplicaEntry{
ReplicaNum: int32(0),
PodSpec: podSpec.DeepCopy(),
RestartPolicy: commonOp.RestartPolicyNever,
}
slots := int32(1)
runPolicy := commonOp.RunPolicy{}

if taskTemplate.TaskTypeVersion == 0 {
mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

// If users don't specify "worker_spec_command" in the task config, the command/args are empty.
// However, in some cases, the workers need command/args.
// For example, in horovod tasks, each worker runs a command launching ssh daemon.
workerReplica.ReplicaNum = mpiTaskExtraArgs.GetNumWorkers()
launcherReplica.ReplicaNum = mpiTaskExtraArgs.GetNumLauncherReplicas()
slots = mpiTaskExtraArgs.GetSlots()

workerSpecCommand := []string{}
if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok {
workerSpecCommand = strings.Split(val, " ")
}
// V1 requires passing worker command as template config parameter
taskTemplateConfig := taskTemplate.GetConfig()
workerSpecCommand := []string{}
if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok {
workerSpecCommand = strings.Split(val, " ")
}

for k := range workerReplica.PodSpec.Containers {
if workerReplica.PodSpec.Containers[k].Name == kubeflowv1.MPIJobDefaultContainerName {
workerReplica.PodSpec.Containers[k].Args = workerSpecCommand
workerReplica.PodSpec.Containers[k].Command = []string{}
}
}

} else if taskTemplate.TaskTypeVersion == 1 {
kfMPITaskExtraArgs := kfplugins.DistributedMPITrainingTask{}

err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfMPITaskExtraArgs)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

launcherReplicaSpec := kfMPITaskExtraArgs.GetLauncherReplicas()
if launcherReplicaSpec != nil {
// flyte commands will be passed as args to the container
common.OverrideContainerSpec(
launcherReplica.PodSpec,
kubeflowv1.MPIJobDefaultContainerName,
launcherReplicaSpec.GetImage(),
launcherReplicaSpec.GetResources(),
launcherReplicaSpec.GetCommand(),
)
launcherReplica.RestartPolicy =
commonOp.RestartPolicy(
common.ParseRestartPolicy(launcherReplicaSpec.GetRestartPolicy()),
)
}

workerReplicaSpec := kfMPITaskExtraArgs.GetWorkerReplicas()
if workerReplicaSpec != nil {
common.OverrideContainerSpec(
workerReplica.PodSpec,
kubeflowv1.MPIJobDefaultContainerName,
workerReplicaSpec.GetImage(),
workerReplicaSpec.GetResources(),
workerReplicaSpec.GetCommand(),
)
workerReplica.RestartPolicy =
commonOp.RestartPolicy(
common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()),
)
workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas()
}

if kfMPITaskExtraArgs.GetRunPolicy() != nil {
runPolicy = common.ParseRunPolicy(*kfMPITaskExtraArgs.GetRunPolicy())
}

for k := range workersPodSpec.Containers {
workersPodSpec.Containers[k].Args = workerSpecCommand
workersPodSpec.Containers[k].Command = []string{}
} else {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification,
"Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion)
}

if workers == 0 {
if workerReplica.ReplicaNum == 0 {
return nil, fmt.Errorf("number of worker should be more then 0")
}
if launcherReplicas == 0 {
if launcherReplica.ReplicaNum == 0 {
return nil, fmt.Errorf("number of launch worker should be more then 0")
}

jobSpec := kubeflowv1.MPIJobSpec{
SlotsPerWorker: &slots,
MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{},
}

for _, t := range []struct {
podSpec v1.PodSpec
replicaNum *int32
replicaType commonOp.ReplicaType
}{
{*podSpec, &launcherReplicas, kubeflowv1.MPIJobReplicaTypeLauncher},
{*workersPodSpec, &workers, kubeflowv1.MPIJobReplicaTypeWorker},
} {
if *t.replicaNum > 0 {
jobSpec.MPIReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{
Replicas: t.replicaNum,
SlotsPerWorker: &slots,
RunPolicy: runPolicy,
MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.MPIJobReplicaTypeLauncher: {
Replicas: &launcherReplica.ReplicaNum,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: t.podSpec,
Spec: *launcherReplica.PodSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
}
}
RestartPolicy: launcherReplica.RestartPolicy,
},
kubeflowv1.MPIJobReplicaTypeWorker: {
Replicas: &workerReplica.ReplicaNum,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *workerReplica.PodSpec,
},
RestartPolicy: workerReplica.RestartPolicy,
},
},
}

job := &kubeflowv1.MPIJob{
Expand Down
136 changes: 134 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"

"github.com/flyteorg/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand Down Expand Up @@ -68,9 +69,24 @@ func dummyMPICustomObj(workers int32, launcher int32, slots int32) *plugins.Dist
}
}

func dummyMPITaskTemplate(id string, mpiCustomObj *plugins.DistributedMPITrainingTask) *core.TaskTemplate {
func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate {

var mpiObjJSON string
var err error

for _, arg := range args {
switch t := arg.(type) {
case *kfplugins.DistributedMPITrainingTask:
var mpiCustomObj = t
mpiObjJSON, err = utils.MarshalToString(mpiCustomObj)
case *plugins.DistributedMPITrainingTask:
var mpiCustomObj = t
mpiObjJSON, err = utils.MarshalToString(mpiCustomObj)
default:
err = fmt.Errorf("Unkonw input type %T", t)
}
}

mpiObjJSON, err := utils.MarshalToString(mpiCustomObj)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -427,3 +443,119 @@ func TestReplicaCounts(t *testing.T) {
})
}
}

// TestBuildResourceMPIV1 checks that BuildResource, given a task template with
// TaskTypeVersion 1 and a kfplugins.DistributedMPITrainingTask custom payload,
// produces an MPIJob whose launcher and worker replica specs carry the
// per-replica resources, the per-replica commands (surfaced as container Args),
// and the expected replica counts (launcher stays at 1, workers follow
// WorkerReplicas.Replicas).
func TestBuildResourceMPIV1(t *testing.T) {
	launcherCommand := []string{"python", "launcher.py"}
	workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"}
	taskConfig := &kfplugins.DistributedMPITrainingTask{
		LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
			Image: testImage,
			Resources: &core.Resources{
				Requests: []*core.Resources_ResourceEntry{
					{Name: core.Resources_CPU, Value: "250m"},
				},
				Limits: []*core.Resources_ResourceEntry{
					{Name: core.Resources_CPU, Value: "500m"},
				},
			},
			Command: launcherCommand,
		},
		WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
			Replicas: 100,
			Resources: &core.Resources{
				Requests: []*core.Resources_ResourceEntry{
					{Name: core.Resources_CPU, Value: "1024m"},
				},
				Limits: []*core.Resources_ResourceEntry{
					{Name: core.Resources_CPU, Value: "2048m"},
				},
			},
			Command: workerCommand,
		},
		Slots: int32(1),
	}

	launcherResourceRequirements := &corev1.ResourceRequirements{
		Requests: corev1.ResourceList{
			corev1.ResourceCPU: resource.MustParse("250m"),
		},
		Limits: corev1.ResourceList{
			corev1.ResourceCPU: resource.MustParse("500m"),
		},
	}

	workerResourceRequirements := &corev1.ResourceRequirements{
		Requests: corev1.ResourceList{
			corev1.ResourceCPU: resource.MustParse("1024m"),
		},
		Limits: corev1.ResourceList{
			corev1.ResourceCPU: resource.MustParse("2048m"),
		},
	}

	mpiResourceHandler := mpiOperatorResourceHandler{}

	taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
	taskTemplate.TaskTypeVersion = 1

	// Named builtResource (not resource) so the imported resource package used
	// above is not shadowed for the remainder of the function.
	builtResource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate))
	if !assert.NoError(t, err) || !assert.NotNil(t, builtResource) {
		// Nothing below can be checked without a built object; stop here
		// instead of panicking on a nil dereference.
		return
	}

	mpiJob, ok := builtResource.(*kubeflowv1.MPIJob)
	if !assert.True(t, ok) {
		return
	}

	launcherSpec := mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher]
	workerSpec := mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker]
	if !assert.NotNil(t, launcherSpec) || !assert.NotNil(t, workerSpec) {
		return
	}
	// Guard the Containers[0] accesses so a missing container is reported as a
	// test failure rather than an index-out-of-range panic.
	if !assert.NotEmpty(t, launcherSpec.Template.Spec.Containers) ||
		!assert.NotEmpty(t, workerSpec.Template.Spec.Containers) {
		return
	}
	launcherContainer := launcherSpec.Template.Spec.Containers[0]
	workerContainer := workerSpec.Template.Spec.Containers[0]

	assert.Equal(t, int32(1), *launcherSpec.Replicas)
	assert.Equal(t, int32(100), *workerSpec.Replicas)
	assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker)
	assert.Equal(t, *launcherResourceRequirements, launcherContainer.Resources)
	assert.Equal(t, *workerResourceRequirements, workerContainer.Resources)
	assert.Equal(t, launcherCommand, launcherContainer.Args)
	assert.Equal(t, workerCommand, workerContainer.Args)
}

// TestBuildResourceMPIV1WithOnlyWorkerReplica checks BuildResource for a
// TaskTypeVersion 1 template whose kfplugins.DistributedMPITrainingTask sets
// only WorkerReplicas. The worker replica spec must carry the configured
// resources, command (as container Args) and replica count, while the launcher
// is expected to keep a single replica and the args in testArgs (defined
// elsewhere in this test file).
func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) {
	workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"}

	taskConfig := &kfplugins.DistributedMPITrainingTask{
		WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
			Replicas: 100,
			Resources: &core.Resources{
				Requests: []*core.Resources_ResourceEntry{
					{Name: core.Resources_CPU, Value: "1024m"},
				},
				Limits: []*core.Resources_ResourceEntry{
					{Name: core.Resources_CPU, Value: "2048m"},
				},
			},
			// Reuse workerCommand so the configured command and the expected
			// value asserted below cannot drift apart.
			Command: workerCommand,
		},
		Slots: int32(1),
	}

	workerResourceRequirements := &corev1.ResourceRequirements{
		Requests: corev1.ResourceList{
			corev1.ResourceCPU: resource.MustParse("1024m"),
		},
		Limits: corev1.ResourceList{
			corev1.ResourceCPU: resource.MustParse("2048m"),
		},
	}

	mpiResourceHandler := mpiOperatorResourceHandler{}

	taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
	taskTemplate.TaskTypeVersion = 1

	// Named builtResource (not resource) so the imported resource package used
	// above is not shadowed for the remainder of the function.
	builtResource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate))
	if !assert.NoError(t, err) || !assert.NotNil(t, builtResource) {
		// Nothing below can be checked without a built object; stop here
		// instead of panicking on a nil dereference.
		return
	}

	mpiJob, ok := builtResource.(*kubeflowv1.MPIJob)
	if !assert.True(t, ok) {
		return
	}

	launcherSpec := mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher]
	workerSpec := mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker]
	if !assert.NotNil(t, launcherSpec) || !assert.NotNil(t, workerSpec) {
		return
	}
	// Guard the Containers[0] accesses so a missing container is reported as a
	// test failure rather than an index-out-of-range panic.
	if !assert.NotEmpty(t, launcherSpec.Template.Spec.Containers) ||
		!assert.NotEmpty(t, workerSpec.Template.Spec.Containers) {
		return
	}
	launcherContainer := launcherSpec.Template.Spec.Containers[0]
	workerContainer := workerSpec.Template.Spec.Containers[0]

	assert.Equal(t, int32(1), *launcherSpec.Replicas)
	assert.Equal(t, int32(100), *workerSpec.Replicas)
	assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker)
	assert.Equal(t, *workerResourceRequirements, workerContainer.Resources)
	assert.Equal(t, testArgs, launcherContainer.Args)
	assert.Equal(t, workerCommand, workerContainer.Args)
}
2 changes: 2 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
kubeflowv1.PytorchJobDefaultContainerName,
masterReplicaSpec.GetImage(),
masterReplicaSpec.GetResources(),
nil,
)
masterReplica.RestartPolicy =
commonOp.RestartPolicy(
Expand All @@ -121,6 +122,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
kubeflowv1.PytorchJobDefaultContainerName,
workerReplicaSpec.GetImage(),
workerReplicaSpec.GetResources(),
nil,
)
workerReplica.RestartPolicy =
commonOp.RestartPolicy(
Expand Down
Loading

0 comments on commit 9626895

Please sign in to comment.