Skip to content

Commit

Permalink
run: move monitor logic out of run
Browse files Browse the repository at this point in the history
  • Loading branch information
pared committed Mar 1, 2021
1 parent a53fffe commit 5d7a401
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 175 deletions.
2 changes: 1 addition & 1 deletion dvc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def make_checkpoint():
from time import sleep

from dvc.env import DVC_CHECKPOINT, DVC_ROOT
from dvc.stage.run import CheckpointTask
from dvc.stage.monitor import CheckpointTask

if os.getenv(DVC_CHECKPOINT) is None:
return
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dvc.env import DVCLIVE_RESUME
from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.stage.run import CheckpointKilledError
from dvc.stage.monitor import CheckpointKilledError
from dvc.utils import relpath

from .base import (
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from dvc.scm import SCM
from dvc.stage import PipelineStage
from dvc.stage.run import CheckpointKilledError
from dvc.stage.monitor import CheckpointKilledError
from dvc.stage.serialize import to_lockfile
from dvc.utils import dict_sha256
from dvc.utils.fs import remove
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _reproduce_stages(
_repro_callback, checkpoint_func, unchanged
)

from dvc.stage.run import CheckpointKilledError
from dvc.stage.monitor import CheckpointKilledError

try:
ret = _reproduce_stage(stage, **kwargs)
Expand Down
150 changes: 150 additions & 0 deletions dvc/stage/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import functools
import logging
import os
import subprocess
import threading
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, List

from dvc.repo.live import create_summary
from dvc.stage.decorators import relock_repo
from dvc.stage.exceptions import StageCmdFailedError

if TYPE_CHECKING:
from dvc.output import BaseOutput
from dvc.stage import Stage


logger = logging.getLogger(__name__)


class CheckpointKilledError(StageCmdFailedError):
    """Stage command failure type used by the ``checkpoint`` monitor task.

    Adds nothing to ``StageCmdFailedError``; it exists so callers can tell
    this failure apart (it is the ``error_cls`` of ``CheckpointTask``).
    """


class LiveKilledError(StageCmdFailedError):
    """Stage command failure type used by the ``live`` monitor task.

    Adds nothing to ``StageCmdFailedError``; it exists so callers can tell
    this failure apart (it is the ``error_cls`` of ``LiveTask``).
    """


@dataclass
class MonitorTask:
    """Base description of a task that ``Monitor`` runs for a stage process.

    A task pairs a stage and its running process with a zero-argument
    ``execute`` callable. ``Monitor`` calls ``execute`` whenever the file at
    ``signal_path`` exists, and sets ``killed`` after terminating ``proc``
    because ``execute`` raised.

    Subclasses are expected to provide ``name``, ``SIGNAL_FILE`` and
    ``error_cls`` (the base implementations raise ``NotImplementedError``).
    """

    stage: "Stage"
    execute: Callable
    proc: subprocess.Popen
    # Every task needs its own events. A plain ``threading.Event()`` default
    # is evaluated once at class creation and would be shared by all tasks,
    # so killing one task would mark every other task as killed too.
    done: threading.Event = field(default_factory=threading.Event)
    killed: threading.Event = field(default_factory=threading.Event)

    @property
    def name(self) -> str:
        """Short task name used in log messages."""
        raise NotImplementedError

    @property
    def SIGNAL_FILE(self) -> str:
        """File name (inside the repo tmp dir) that triggers the task."""
        raise NotImplementedError

    @property
    def error_cls(self) -> type:
        """Exception class associated with this task."""
        raise NotImplementedError

    @property
    def signal_path(self) -> str:
        """Full path of the signal file: ``<repo tmp_dir>/<SIGNAL_FILE>``."""
        return os.path.join(self.stage.repo.tmp_dir, self.SIGNAL_FILE)

    def after_run(self):
        """Hook called by ``Monitor`` on exit; does nothing by default."""


class CheckpointTask(MonitorTask):
    """Monitor task that saves and commits the stage, then runs a callback.

    Triggered by the ``DVC_CHECKPOINT`` signal file.
    """

    name = "checkpoint"
    SIGNAL_FILE = "DVC_CHECKPOINT"
    error_cls = CheckpointKilledError

    def __init__(
        self, stage: "Stage", callback_func: Callable, proc: subprocess.Popen
    ):
        # Bind the stage and the callback up front so that ``execute`` can
        # be invoked without arguments.
        on_signal = functools.partial(
            CheckpointTask._run_callback, stage, callback_func
        )
        super().__init__(stage, on_signal, proc)

    @staticmethod
    @relock_repo  # NOTE(review): presumably re-acquires the repo lock; confirm
    def _run_callback(stage, callback_func):
        """Save and commit ``stage`` (missing outputs allowed), then call back."""
        stage.save(allow_missing=True)
        stage.commit(allow_missing=True)
        logger.debug("Running checkpoint callback for stage '%s'", stage)
        callback_func()


class LiveTask(MonitorTask):
    """Monitor task that calls ``create_summary`` for a live output.

    Triggered by the ``DVC_LIVE`` signal file.
    """

    name = "live"
    SIGNAL_FILE = "DVC_LIVE"
    error_cls = LiveKilledError

    def __init__(
        self, stage: "Stage", out: "BaseOutput", proc: subprocess.Popen
    ):
        # ``execute`` becomes ``create_summary(out)`` with ``out`` pre-bound.
        summarize = functools.partial(create_summary, out)
        super().__init__(stage, summarize, proc)

    def after_run(self):
        """Run ``execute`` one last time once monitoring has stopped."""
        # make sure summary is prepared for all the data
        self.execute()


class Monitor:
    """Run monitor tasks on a background thread while a stage command runs.

    Used as a context manager: entering starts a thread that polls every
    task's signal file; exiting stops that thread and then calls each task's
    ``after_run`` hook.
    """

    # Seconds the polling thread waits between two rounds over the tasks.
    AWAIT: float = 1.0

    def __init__(self, tasks: List["MonitorTask"]):
        self.done = threading.Event()
        self.tasks = tasks
        self.monitor_thread = threading.Thread(
            target=Monitor._loop, args=(self.tasks, self.done),
        )

    def __enter__(self):
        """Start the polling thread and return this monitor."""
        self.monitor_thread.start()
        # Returning ``self`` makes ``with Monitor(tasks) as monitor:`` bind
        # the monitor instead of ``None``; callers that ignore the bound
        # value are unaffected.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the polling thread, then run every task's ``after_run``."""
        self.done.set()
        self.monitor_thread.join()
        for t in self.tasks:
            t.after_run()

    @staticmethod
    def kill(proc):
        """Terminate ``proc`` and wait for it to exit.

        On Windows (``os.name == "nt"``) the whole process tree is killed
        through ``taskkill`` instead.
        """
        if os.name == "nt":
            return Monitor._kill_nt(proc)
        proc.terminate()
        proc.wait()

    @staticmethod
    def _kill_nt(proc):
        # windows stages are spawned with shell=True, proc is the shell process
        # and not the actual stage process - we have to kill the entire tree
        subprocess.call(["taskkill", "/F", "/T", "/PID", str(proc.pid)])

    @staticmethod
    def _loop(tasks: List["MonitorTask"], done: threading.Event):
        """Poll the tasks' signal files until ``done`` is set.

        For every task whose signal file exists, ``execute`` is called. If it
        raises, the error is logged, the task's process is killed and its
        ``killed`` event is set. The signal file is removed in either case.
        """
        while True:
            for task in tasks:
                if os.path.exists(task.signal_path):
                    try:
                        task.execute()
                    except Exception:  # pylint: disable=broad-except
                        logger.exception(
                            "Error running '%s' task, '%s' will be aborted",
                            task.name,
                            task.stage,
                        )
                        Monitor.kill(task.proc)
                        task.killed.set()
                    finally:
                        logger.debug(
                            "Removing signal file for '%s' task", task.name
                        )
                        os.remove(task.signal_path)
            # ``wait`` doubles as the sleep between rounds and returns True
            # as soon as ``done`` is set.
            if done.wait(Monitor.AWAIT):
                return
148 changes: 7 additions & 141 deletions dvc/stage/run.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,20 @@
import functools
import logging
import os
import signal
import subprocess
import threading
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, List

if TYPE_CHECKING:
from dvc.stage import Stage
from dvc.output import BaseOutput

from funcy import first

from dvc.repo.live import create_summary
from dvc.stage.monitor import Monitor
from dvc.utils import fix_env

from .decorators import relock_repo, unlocked_repo
from .decorators import unlocked_repo
from .exceptions import StageCmdFailedError

logger = logging.getLogger(__name__)


class CheckpointKilledError(StageCmdFailedError):
    """Stage command failure type used by the ``checkpoint`` monitor task.

    Adds nothing to ``StageCmdFailedError``; it exists so callers can tell
    this failure apart (it is the ``error_cls`` of ``CheckpointTask``).
    """


class LiveKilledError(StageCmdFailedError):
    """Stage command failure type used by the ``live`` monitor task.

    Adds nothing to ``StageCmdFailedError``; it exists so callers can tell
    this failure apart (it is the ``error_cls`` of ``LiveTask``).
    """


def _make_cmd(executable, cmd):
if executable is None:
return cmd
Expand Down Expand Up @@ -124,12 +109,17 @@ def _run(stage, executable, cmd, checkpoint_func, **kwargs):


def _get_monitor_tasks(stage, checkpoint_func, proc):
    """Build the list of monitor tasks to run alongside ``proc``.

    A ``CheckpointTask`` is added when ``checkpoint_func`` is truthy. A
    ``LiveTask`` is added for the first output of ``stage`` whose ``live``
    setting is truthy and has a truthy ``"html"`` entry.
    """
    tasks = []

    if checkpoint_func:
        from .monitor import CheckpointTask

        tasks.append(CheckpointTask(stage, checkpoint_func, proc))

    live_out = first(
        out for out in stage.outs if out.live and out.live["html"]
    )
    if live_out:
        from .monitor import LiveTask

        tasks.append(LiveTask(stage, live_out, proc))

    return tasks
Expand Down Expand Up @@ -168,127 +158,3 @@ def run_stage(

run = cmd_run if dry else unlocked_repo(cmd_run)
run(stage, dry=dry, checkpoint_func=checkpoint_func, run_env=run_env)


@dataclass
class MonitorTask:
    """Base description of a task that ``Monitor`` runs for a stage process.

    A task pairs a stage and its running process with a zero-argument
    ``execute`` callable. ``Monitor`` calls ``execute`` whenever the file at
    ``signal_path`` exists, and sets ``killed`` after terminating ``proc``
    because ``execute`` raised.

    Subclasses are expected to provide ``name``, ``SIGNAL_FILE`` and
    ``error_cls`` (the base implementations raise ``NotImplementedError``).
    """

    stage: "Stage"
    execute: Callable
    proc: subprocess.Popen
    # Every task needs its own events. A plain ``threading.Event()`` default
    # is evaluated once at class creation and would be shared by all tasks,
    # so killing one task would mark every other task as killed too.
    done: threading.Event = field(default_factory=threading.Event)
    killed: threading.Event = field(default_factory=threading.Event)

    @property
    def name(self) -> str:
        """Short task name used in log messages."""
        raise NotImplementedError

    @property
    def SIGNAL_FILE(self) -> str:
        """File name (inside the repo tmp dir) that triggers the task."""
        raise NotImplementedError

    @property
    def error_cls(self) -> type:
        """Exception class associated with this task."""
        raise NotImplementedError

    @property
    def signal_path(self) -> str:
        """Full path of the signal file: ``<repo tmp_dir>/<SIGNAL_FILE>``."""
        return os.path.join(self.stage.repo.tmp_dir, self.SIGNAL_FILE)

    def after_run(self):
        """Hook called by ``Monitor`` on exit; does nothing by default."""


class CheckpointTask(MonitorTask):
    """Monitor task that saves and commits the stage, then runs a callback.

    Triggered by the ``DVC_CHECKPOINT`` signal file.
    """

    name = "checkpoint"
    SIGNAL_FILE = "DVC_CHECKPOINT"
    error_cls = CheckpointKilledError

    def __init__(
        self, stage: "Stage", callback_func: Callable, proc: subprocess.Popen
    ):
        # Bind the stage and the callback up front so that ``execute`` can
        # be invoked without arguments.
        on_signal = functools.partial(
            CheckpointTask._run_callback, stage, callback_func
        )
        super().__init__(stage, on_signal, proc)

    @staticmethod
    @relock_repo  # NOTE(review): presumably re-acquires the repo lock; confirm
    def _run_callback(stage, callback_func):
        """Save and commit ``stage`` (missing outputs allowed), then call back."""
        stage.save(allow_missing=True)
        stage.commit(allow_missing=True)
        logger.debug("Running checkpoint callback for stage '%s'", stage)
        callback_func()


class LiveTask(MonitorTask):
    """Monitor task that calls ``create_summary`` for a live output.

    Triggered by the ``DVC_LIVE`` signal file.
    """

    name = "live"
    SIGNAL_FILE = "DVC_LIVE"
    error_cls = LiveKilledError

    def __init__(
        self, stage: "Stage", out: "BaseOutput", proc: subprocess.Popen
    ):
        # ``execute`` becomes ``create_summary(out)`` with ``out`` pre-bound.
        summarize = functools.partial(create_summary, out)
        super().__init__(stage, summarize, proc)

    def after_run(self):
        """Run ``execute`` one last time once monitoring has stopped."""
        # make sure summary is prepared for all the data
        self.execute()


class Monitor:
    """Run monitor tasks on a background thread while a stage command runs.

    Used as a context manager: entering starts a thread that polls every
    task's signal file; exiting stops that thread and then calls each task's
    ``after_run`` hook.
    """

    # Seconds the polling thread waits between two rounds over the tasks.
    AWAIT: float = 1.0

    def __init__(self, tasks: List["MonitorTask"]):
        self.done = threading.Event()
        self.tasks = tasks
        self.monitor_thread = threading.Thread(
            target=Monitor._loop, args=(self.tasks, self.done),
        )

    def __enter__(self):
        """Start the polling thread and return this monitor."""
        self.monitor_thread.start()
        # Returning ``self`` makes ``with Monitor(tasks) as monitor:`` bind
        # the monitor instead of ``None``; callers that ignore the bound
        # value are unaffected.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the polling thread, then run every task's ``after_run``."""
        self.done.set()
        self.monitor_thread.join()
        for t in self.tasks:
            t.after_run()

    @staticmethod
    def kill(proc):
        """Terminate ``proc`` and wait for it to exit.

        On Windows (``os.name == "nt"``) the whole process tree is killed
        through ``taskkill`` instead.
        """
        if os.name == "nt":
            return Monitor._kill_nt(proc)
        proc.terminate()
        proc.wait()

    @staticmethod
    def _kill_nt(proc):
        # windows stages are spawned with shell=True, proc is the shell process
        # and not the actual stage process - we have to kill the entire tree
        subprocess.call(["taskkill", "/F", "/T", "/PID", str(proc.pid)])

    @staticmethod
    def _loop(tasks: List["MonitorTask"], done: threading.Event):
        """Poll the tasks' signal files until ``done`` is set.

        For every task whose signal file exists, ``execute`` is called. If it
        raises, the error is logged, the task's process is killed and its
        ``killed`` event is set. The signal file is removed in either case.
        """
        while True:
            for task in tasks:
                if os.path.exists(task.signal_path):
                    try:
                        task.execute()
                    except Exception:  # pylint: disable=broad-except
                        logger.exception(
                            "Error running '%s' task, '%s' will be aborted",
                            task.name,
                            task.stage,
                        )
                        Monitor.kill(task.proc)
                        task.killed.set()
                    finally:
                        logger.debug(
                            "Removing signal file for '%s' task", task.name
                        )
                        os.remove(task.signal_path)
            # ``wait`` doubles as the sleep between rounds and returns True
            # as soon as ``done`` is set.
            if done.wait(Monitor.AWAIT):
                return
Loading

0 comments on commit 5d7a401

Please sign in to comment.