Skip to content

Commit

Permalink
resume and async resume
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed May 21, 2024
1 parent 907705d commit d7266e2
Showing 1 changed file with 52 additions and 31 deletions.
83 changes: 52 additions & 31 deletions metaflow/metaflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,61 +105,82 @@ def __init__(
if profile:
self.env_vars["METAFLOW_PROFILE"] = profile
self.spm = SubprocessManager()
self.api = MetaflowAPI.from_cli(self.flow_file, start)(**kwargs)
self.top_level_kwargs = kwargs
self.api = MetaflowAPI.from_cli(self.flow_file, start)

def __enter__(self):
return self

async def __aenter__(self):
return self

def __get_executing_run(self, tfp_pathspec, command_obj):
try:
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)
except TimeoutError as e:
stdout_log = open(command_obj.log_files["stdout"]).read()
stderr_log = open(command_obj.log_files["stderr"]).read()
command = " ".join(command_obj.command)
error_message = "Error executing: '%s':\n" % command
if stdout_log.strip():
error_message += "\nStdout:\n%s\n" % stdout_log
if stderr_log.strip():
error_message += "\nStderr:\n%s\n" % stderr_log
raise RuntimeError(error_message) from e

def run(self, **kwargs):
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
command = self.api(**self.top_level_kwargs).run(
pathspec_file=tfp_pathspec.name, **kwargs
)

pid = self.spm.run_command([sys.executable, *command], env=self.env_vars)
command_obj = self.spm.get(pid)

try:
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)
except TimeoutError as e:
stdout_log = open(command_obj.log_files["stdout"]).read()
stderr_log = open(command_obj.log_files["stderr"]).read()
command = " ".join(command_obj.command)
error_message = "Error executing: '%s':\n" % command
if stdout_log.strip():
error_message += "\nStdout:\n%s\n" % stdout_log
if stderr_log.strip():
error_message += "\nStderr:\n%s\n" % stderr_log
raise RuntimeError(error_message) from e
return self.__get_executing_run(tfp_pathspec, command_obj)

def resume(self, **kwargs):
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api(**self.top_level_kwargs).resume(
pathspec_file=tfp_pathspec.name, **kwargs
)

pid = self.spm.run_command([sys.executable, *command], env=self.env_vars)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_pathspec, command_obj)

async def async_run(self, **kwargs):
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
command = self.api(**self.top_level_kwargs).run(
pathspec_file=tfp_pathspec.name, **kwargs
)

pid = await self.spm.async_run_command(
[sys.executable, *command], env=self.env_vars
)
command_obj = self.spm.get(pid)

return self.__get_executing_run(tfp_pathspec, command_obj)

async def async_resume(self, **kwargs):
with tempfile.TemporaryDirectory() as temp_dir:
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
command = self.api(**self.top_level_kwargs).resume(
pathspec_file=tfp_pathspec.name, **kwargs
)

pid = await self.spm.async_run_command(
[sys.executable, *command], env=self.env_vars
)
command_obj = self.spm.get(pid)

try:
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
run_object = Run(pathspec, _namespace_check=False)
return ExecutingRun(self, command_obj, run_object)
except TimeoutError as e:
stdout_log = open(command_obj.log_files["stdout"]).read()
stderr_log = open(command_obj.log_files["stderr"]).read()
command = " ".join(command_obj.command)
error_message = "Error executing: '%s':\n" % command
if stdout_log.strip():
error_message += "\nStdout:\n%s\n" % stdout_log
if stderr_log.strip():
error_message += "\nStderr:\n%s\n" % stderr_log
raise RuntimeError(error_message) from e
return self.__get_executing_run(tfp_pathspec, command_obj)

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
Expand Down

0 comments on commit d7266e2

Please sign in to comment.