Skip to content

Commit

Permalink
Put the full PyTorch prototype in the jsonnet file. (kubeflow#1119)
Browse files Browse the repository at this point in the history
* The current pattern in Kubeflow is to put the complete prototype in
  the jsonnet file. This way the result of ks generate is a .jsonnet
  file containing a full spec. This makes it easy for users to do
  complex modifications starting with the prototype as an example.

* Configure the prototype to do mnist by default.

Fix kubeflow#1114
  • Loading branch information
jlewi committed Jul 11, 2018
1 parent 6b14d8e commit 225c8be
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 82 deletions.
75 changes: 72 additions & 3 deletions kubeflow/pytorch-job/prototypes/pytorch-job.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// @param name string Name to give to each of the components
// @optionalParam namespace string null Namespace to use for the components. It is automatically inherited from the environment if not set.
// @optionalParam args string null Comma separated list of arguments to pass to the job
// @optionalParam image string null The docker image to use for the job.
// @optionalParam image string gcr.io/kubeflow-examples/pytorch-dist-mnist:v20180702-a57993c The docker image to use for the job.
// @optionalParam image_gpu string null The docker image to use when using GPUs.
// @optionalParam num_masters number 1 The number of masters to use
// @optionalParam num_ps number 0 The number of ps to use
Expand All @@ -14,6 +14,75 @@

local k = import "k.libsonnet";

local all = import "kubeflow/pytorch-job/pytorch-job.libsonnet";
local util = {
pytorchJobReplica(replicaType, number, args, image, numGpus=0)::
local baseContainer = {
image: image,
name: "pytorch",
};
local containerArgs = if std.length(args) > 0 then
{
args: args,
}
else {};
local resources = if numGpus > 0 then {
resources: {
limits: {
"nvidia.com/gpu": numGpus,
},
},
} else {};
if number > 0 then
{
replicas: number,
template: {
spec: {
containers: [
baseContainer + containerArgs + resources,
],
restartPolicy: "OnFailure",
},
},
replicaType: replicaType,
}
else {},
};

std.prune(k.core.v1.list.new(all.pyTorchJobPrototype(params, env)))
local namespace = env.namespace;
local name = params.name;

local argsParam = params.args;
local args =
if argsParam == "null" then
[]
else
std.split(argsParam, ",");

local image = params.image;
local imageGpu = params.image_gpu;
local numMasters = params.num_masters;
local numWorkers = params.num_workers;
local numGpus = params.num_gpus;

local workerSpec = if numGpus > 0 then
util.pytorchJobReplica("WORKER", numWorkers, args, imageGpu, numGpus)
else
util.pytorchJobReplica("WORKER", numWorkers, args, image);

local masterSpec = util.pytorchJobReplica("MASTER", numMasters, args, image);
local replicas = [masterSpec, workerSpec];


local job = {
apiVersion: "kubeflow.org/v1alpha1",
kind: "PyTorchJob",
metadata: {
name: name,
namespace: namespace,
},
spec: {
replicaSpecs: replicas,
},
};

std.prune(k.core.v1.list.new([job]))
79 changes: 0 additions & 79 deletions kubeflow/pytorch-job/pytorch-job.libsonnet

This file was deleted.

0 comments on commit 225c8be

Please sign in to comment.