Skip to content

Commit

Permalink
Support restarting training job (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
hougangliu authored and k8s-ci-robot committed Oct 30, 2019
1 parent 606736d commit 1608d28
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 16 deletions.
8 changes: 5 additions & 3 deletions cmd/metricscollector/v1alpha3/file-metricscollector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"context"
"flag"
"os"
"path/filepath"
"strings"

"github.com/hpcloud/tail"
Expand Down Expand Up @@ -84,9 +85,10 @@ func main() {

go printMetricsFile(*metricsFileName)
wopts := common.WaitPidsOpts{
PollInterval: *pollInterval,
Timeout: *timeout,
WaitAll: *waitAll,
PollInterval: *pollInterval,
Timeout: *timeout,
WaitAll: *waitAll,
CompletedMarkedDirPath: filepath.Dir(*metricsFileName),
}
if err := common.Wait(wopts); err != nil {
klog.Fatalf("Failed to wait for worker container: %v", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def parse_options():
raise Exception("Invalid katib manager service address: %s" %
opt.manager_server_addr)

WaitOtherMainProcesses()
WaitOtherMainProcesses(completed_marked_dir=opt.dir_path)

mc = MetricsCollector(opt.metric_names.split(','))
observation_log = mc.parse_file(opt.dir_path)
Expand Down
2 changes: 2 additions & 0 deletions pkg/metricscollector/v1alpha3/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ const (

MetricCollectorContainerName = "metrics-collector"
MetricLoggerCollectorContainerName = "metrics-logger-and-collector"

TrainingCompleted = "completed"
)
20 changes: 17 additions & 3 deletions pkg/metricscollector/v1alpha3/common/pns.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ package common

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"time"

gops "github.com/mitchellh/go-ps"
Expand All @@ -27,9 +30,10 @@ import (
var ErrWaitPidTimeout = fmt.Errorf("Timed out waiting for PID to complete")

type WaitPidsOpts struct {
PollInterval time.Duration
Timeout time.Duration
WaitAll bool
PollInterval time.Duration
Timeout time.Duration
WaitAll bool
CompletedMarkedDirPath string
}

func Wait(opts WaitPidsOpts) error {
Expand Down Expand Up @@ -95,6 +99,16 @@ func WaitPIDS(pids []int, opts ...WaitPidsOpts) error {
_, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
if opts[0].CompletedMarkedDirPath != "" {
markFile := filepath.Join(opts[0].CompletedMarkedDirPath, fmt.Sprintf("%d.pid", pid))
if data, err := ioutil.ReadFile(markFile); err != nil {
return fmt.Errorf("Process %d hadn't completed: %v", pid, err)
} else {
if strings.TrimSpace(string(data)) != TrainingCompleted {
return fmt.Errorf("Process %d hadn't completed", pid)
}
}
}
if waitAll {
finishedPids = append(finishedPids, pid)
if len(finishedPids) == len(pids) {
Expand Down
12 changes: 9 additions & 3 deletions pkg/metricscollector/v1alpha3/common/pns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def GetOtherMainProcesses():
pids.add(pid)
return pids

def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False):
def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False, completed_marked_dir=""):
start = 0
pids = set(pids)
if poll_interval_seconds <= 0:
Expand All @@ -26,6 +26,12 @@ def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False
if os.path.isdir(path):
continue
else:
if completed_marked_dir:
mark_file = os.path.join(completed_marked_dir, "%d.pid" % pid)
with open(mark_file) as file_obj:
contents = file_obj.read()
if contents.strip() != "completed":
raise Exception("Pid %d hadn't completed" % pid)
if is_wait_all:
stop_pids.add(pid)
else:
Expand All @@ -35,5 +41,5 @@ def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False
time.sleep(poll_interval_seconds)
start = start + poll_interval_seconds

def WaitOtherMainProcesses(poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False):
return WaitPIDs(GetOtherMainProcesses(), poll_interval_seconds, timeout_seconds, is_wait_all)
def WaitOtherMainProcesses(poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False, completed_marked_dir=""):
return WaitPIDs(GetOtherMainProcesses(), poll_interval_seconds, timeout_seconds, is_wait_all, completed_marked_dir)
23 changes: 17 additions & 6 deletions pkg/webhook/v1alpha3/pod/inject_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (s *sidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error)

if mountPath, pathKind := getMountPath(trial.Spec.MetricsCollector); mountPath != "" {
if err = wrapWorkerContainer(
mutatedPod, kind, mountPath, trial.Spec.MetricsCollector); err != nil {
mutatedPod, kind, mountPath, pathKind, trial.Spec.MetricsCollector); err != nil {
return nil, err
}
if err = mutateVolume(mutatedPod, kind, mountPath, sidecarContainerName, pathKind); err != nil {
Expand Down Expand Up @@ -229,10 +229,8 @@ func getMountPath(mc common.MetricsCollectorSpec) (string, common.FileSystemKind

func wrapWorkerContainer(
pod *v1.Pod, jobKind, metricsFile string,
pathKind common.FileSystemKind,
mc common.MetricsCollectorSpec) error {
if mc.Collector.Kind != common.StdOutCollector {
return nil
}
index := -1
for i, c := range pod.Spec.Containers {
jobProvider, err := jobv1alpha3.New(jobKind)
Expand All @@ -255,15 +253,28 @@ func wrapWorkerContainer(
if c.Args != nil {
args = append(args, c.Args...)
}
redirectStr := fmt.Sprintf("1>%s 2>&1", metricsFile)
args = append(args, redirectStr)
if mc.Collector.Kind == common.StdOutCollector {
redirectStr := fmt.Sprintf("1>%s 2>&1", metricsFile)
args = append(args, redirectStr)
}
args = append(args, "&&", getMarkCompletedCommand(metricsFile, pathKind))
argsStr := strings.Join(args, " ")
c.Command = command
c.Args = []string{argsStr}
}
return nil
}

func getMarkCompletedCommand(mountPath string, pathKind common.FileSystemKind) string {
dir := mountPath
if pathKind == common.FileKind {
dir = filepath.Dir(mountPath)
}
// $$ is process id in shell
pidFile := filepath.Join(dir, "$$$$.pid")
return fmt.Sprintf("echo %s > %s", mccommon.TrainingCompleted, pidFile)
}

func mutateVolume(pod *v1.Pod, jobKind, mountPath, sidecarContainerName string, pathKind common.FileSystemKind) error {
metricsVol := v1.Volume{
Name: common.MetricsVolume,
Expand Down

0 comments on commit 1608d28

Please sign in to comment.