Commit

clean up pr for review
darinyu committed Feb 27, 2024
1 parent a034789 commit 346cfe0
Showing 8 changed files with 22 additions and 289 deletions.
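Most of the deletions below strip ad-hoc timing instrumentation and debug prints from the resume/clone path. A minimal sketch of the pattern being removed, using only the standard library and a hypothetical timed() helper rather than code from this repository:

    import time

    def timed(label, fn, *args, **kwargs):
        # Wall-clock timing around a call, reported with a bare print()
        # rather than the runtime logger; this is the pattern the commit deletes.
        start_time = time.time()
        result = fn(*args, **kwargs)
        print(f"{label} finished in {time.time() - start_time:.2f} secs")
        return result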
6 changes: 1 addition & 5 deletions metaflow/cli.py
@@ -745,9 +745,6 @@ def resume(
run_id_file=None,
resume_identifier=None,
):
import time

start_time = time.time()
before_run(obj, tags, decospecs + obj.environment.decospecs())

if origin_run_id is None:
@@ -809,7 +806,6 @@ def resume(
runtime.clone_original_run()
else:
runtime.execute()
print(f"Resume finished in {time.time() - start_time:.2f} secs")


@tracing.cli_entrypoint("cli/run")
@@ -1037,7 +1033,7 @@ def start(
ctx.obj.monitor.start()

ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0](
ctx.obj.environment, ctx.obj.flow.name, ctx.obj.event_logger, ctx.obj.monitor
ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor
)

ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == datastore][0]
22 changes: 0 additions & 22 deletions metaflow/clone_util.py
@@ -13,21 +13,14 @@ def clone_task_helper(
origin_ds_set=None,
attempt_id=0,
):
data = []
print(
f"Cloning task from {flow_name}/{clone_run_id}/{step_name}/{task_id} to {flow_name}/{run_id}/{step_name}/{task_id}"
)
start_time = time.time()
# 1. initialize output datastore
output = flow_datastore.get_task_datastore(
run_id, step_name, task_id, attempt=attempt_id, mode="w"
)
output.init_task()
# data.extend(output.init_task_iter())
end_time = time.time()
print(
f"Cloning task {flow_name}/{run_id}/{step_name}/{task_id}, step 1, finished in {end_time - start_time:.2f} secs"
)

origin_run_id, origin_step_name, origin_task_id = clone_run_id, step_name, task_id
# 2. initialize origin datastore
@@ -42,10 +35,6 @@
)
metadata_tags = ["attempt_id:{0}".format(attempt_id)]
output.clone(origin)
end_time = time.time()
# print(
# f"Cloning task {flow_name}/{run_id}/{step_name}/{task_id}, step 2.1, finished in {end_time - start_time:.2f} secs"
# )
_ = metadata_service.register_task_id(
run_id,
step_name,
@@ -77,15 +66,4 @@ def clone_task_helper(
),
],
)
end_time = time.time()
# print(
# f"Cloning task {flow_name}/{run_id}/{step_name}/{task_id}, step 2.2, finished in {end_time - start_time:.2f} secs"
# )
# output.done(write_to_storage=False)
output.done()
# data.extend(output.done_iter())
end_time = time.time()
print(
f"Cloning task {flow_name}/{run_id}/{step_name}/{task_id} finished in {end_time - start_time:.2f} secs"
)
return data
43 changes: 0 additions & 43 deletions metaflow/datastore/task_datastore.py
@@ -231,11 +231,6 @@ def init_task(self):
"""
self.save_metadata({self.METADATA_ATTEMPT_SUFFIX: {"time": time.time()}})

def init_task_iter(self):
return self.get_save_file_iter(
{self.METADATA_ATTEMPT_SUFFIX: {"time": time.time()}}
)

@only_if_not_done
@require_mode("w")
def save_artifacts(self, artifacts_iter, force_v4=False, len_hint=0):
@@ -588,21 +583,6 @@ def is_none(self, name):
# Slow path since this has to get the object from the datastore
return self.get(name) is None

def done_iter(self):
return self.get_save_file_iter(
{
self.METADATA_DATA_SUFFIX: {
"datastore": self.TYPE,
"version": "1.0",
"attempt": self._attempt,
"python_version": sys.version,
"objects": self._objects,
"info": self._info,
},
self.METADATA_DONE_SUFFIX: "",
}
)

@only_if_not_done
@require_mode("w")
def done(self, write_to_storage=True):
@@ -913,29 +893,6 @@ def blob_iter():

self._storage_impl.save_bytes(blob_iter(), overwrite=allow_overwrite)

def get_save_file_iter(self, contents, add_attempt=True):
def convert(contents):
return {k: json.dumps(v).encode("utf-8") for k, v in contents.items()}

result = []
for name, value in convert(contents).items():
if add_attempt:
path = self._storage_impl.path_join(
self._path, self._metadata_name_for_attempt(name)
)
else:
path = self._storage_impl.path_join(self._path, name)
if isinstance(value, (RawIOBase, BufferedIOBase)) and value.readable():
result.append((path, value))
elif is_stringish(value):
result.append((path, to_fileobj(value)))
else:
raise DataException(
"Metadata1 '%s' for task '%s' has an invalid type: %s"
% (name, self._path, type(value))
)
return result

def _load_file(self, names, add_attempt=True):
"""
Loads files from the TaskDataStore directory. These can be metadata,
4 changes: 2 additions & 2 deletions metaflow/metadata/metadata.py
@@ -727,11 +727,11 @@ def _reconstruct_metadata_for_attempt(all_metadata, attempt_id):

return post_filter

def __init__(self, environment, flow_name, event_logger, monitor):
def __init__(self, environment, flow, event_logger, monitor):
self._task_id_seq = -1
self.sticky_tags = set()
self.sticky_sys_tags = set()
self._flow_name = flow_name
self._flow_name = flow.name
self._event_logger = event_logger
self._monitor = monitor
self._environment = environment
4 changes: 2 additions & 2 deletions metaflow/plugins/metadata/local.py
@@ -17,9 +17,9 @@
class LocalMetadataProvider(MetadataProvider):
TYPE = "local"

def __init__(self, environment, flow_name, event_logger, monitor):
def __init__(self, environment, flow, event_logger, monitor):
super(LocalMetadataProvider, self).__init__(
environment, flow_name, event_logger, monitor
environment, flow.name, event_logger, monitor
)

@classmethod
4 changes: 2 additions & 2 deletions metaflow/plugins/metadata/service.py
@@ -42,9 +42,9 @@ class ServiceMetadataProvider(MetadataProvider):
_supports_attempt_gets = None
_supports_tag_mutation = None

def __init__(self, environment, flow_name, event_logger, monitor):
def __init__(self, environment, flow, event_logger, monitor):
super(ServiceMetadataProvider, self).__init__(
environment, flow_name, event_logger, monitor
environment, flow.name, event_logger, monitor
)
self.url_task_template = os.path.join(
SERVICE_URL,
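The metadata provider changes above all make the same switch: the constructor receives the flow object rather than its name string, and the name is derived from the object. A rough, simplified sketch of the shape, not the actual Metaflow classes:

    class Provider:
        def __init__(self, environment, flow, event_logger, monitor):
            # The name is read off the flow object instead of being passed in.
            self._flow_name = flow.name
            self._environment = environment
            self._event_logger = event_logger
            self._monitor = monitor

    # Call sites now construct the provider with the flow itself,
    # e.g. Provider(env, flow, logger, monitor) instead of
    # Provider(env, flow.name, logger, monitor).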
91 changes: 15 additions & 76 deletions metaflow/runtime.py
@@ -13,6 +13,7 @@
from datetime import datetime
from io import BytesIO
from functools import partial
from concurrent import futures

from metaflow.datastore.exceptions import DataException

@@ -241,7 +242,7 @@ def clone_task(self, step_name, task_id):
f"Cloning task {self._flow.name}/{self._run_id}/{step_name}/{task_id}",
system_msg=True,
)
return clone_task_helper(
clone_task_helper(
self._flow.name,
self._clone_run_id,
self._run_id,
@@ -261,45 +262,24 @@ def clone_original_run(self):
self._logger(skip_reason, system_msg=True)
return
self._metadata.start_run_heartbeat(self._flow.name, self._run_id)
from metaflow import Run
self._logger(
f"Start cloning original run: {self._flow.name}/{self._clone_run_id}",
system_msg=True,
)

run = Run(f"{self._flow.name}/{self._clone_run_id}")
self._logger("Start cloning original run: %s" % (run), system_msg=True)
inputs = []
for step in run:
for task in step:
_, _, step_name, task_id = task.pathspec.split("/")
if task.successful:
# self.clone_task(step_name, task_id)
inputs.append((step_name, task_id))

inputs2 = []
if self._origin_ds_set:
for k, v in self._origin_ds_set.pathspec_cache.items():
_, step_name, task_id = k.split("/")
if v["_task_ok"] and step_name != "_parameters":
inputs2.append((step_name, task_id))
print("inputs2: ", inputs2)

self._logger("Finish up non-s3 work: %s" % (run), system_msg=True)
from concurrent import futures
import time

with futures.ThreadPoolExecutor(max_workers=64) as executor:

for task_ds in self._origin_ds_set:
_, step_name, task_id = task_ds.pathspec.split("/")
if task_ds["_task_ok"] and step_name != "_parameters":
inputs.append((step_name, task_id))

with futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor:
all_tasks = [
executor.submit(self.clone_task, step_name, task_id)
for (step_name, task_id) in inputs
]
res, _ = futures.wait(all_tasks)
results = []
for future in res:
results.extend(future.result())
# print("final results: ", results)
# print("copying files to s3 started")

# start_time = time.time()
# self._flow_datastore._storage_impl.save_bytes(results, overwrite=True, len_hint=50)
# print(f"copying files takes {time.time() - start_time:.2f} secs")
_, _ = futures.wait(all_tasks)
self._logger("Cloning original run is done", system_msg=True)
self._params_task.mark_resume_done()
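The rewritten clone_original_run above collects (step_name, task_id) pairs from the origin datastore set and fans the cloning out over a thread pool. A minimal standalone sketch of that submit-and-wait pattern, with a stand-in clone_task instead of the real datastore-backed one:

    from concurrent import futures

    def clone_task(step_name, task_id):
        # Stand-in for cloning one task's datastore and metadata.
        return f"cloned {step_name}/{task_id}"

    inputs = [("start", "1"), ("train", "2"), ("end", "3")]
    with futures.ThreadPoolExecutor(max_workers=4) as executor:
        all_tasks = [executor.submit(clone_task, s, t) for s, t in inputs]
        done, _ = futures.wait(all_tasks)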

@@ -1371,12 +1351,7 @@ def __str__(self):
class Worker(object):
def __init__(self, task, max_logs_size):
self.task = task
if self.task.is_cloned and self.task.clone_origin:
print("launch clone?")
self._proc = self._launch_clone()
else:
print("launch original.")
self._proc = self._launch()
self._proc = self._launch()

if task.retries > task.user_code_retries:
self.task.log(
@@ -1406,24 +1381,6 @@ def __init__(self, task, max_logs_size):
# noticed by the runtime and queried for its state (whether or
# not it is properly shut down)

def _launch_clone(self):
env = dict(os.environ)
env["PYTHONUNBUFFERED"] = "x"
cmd = [
"python",
"-c",
f"from metaflow.util import print_hello; print_hello('{self.task.flow_name}', '{self.task.clone_run_id}', '{self.task.run_id}', '{self.task.step}', '{self.task.task_id}', '{self.task._flow_datastore.default_storage_impl.datastore_root}')",
]
print("running cmd!", " ".join(cmd))
return subprocess.Popen(
cmd,
env=env,
bufsize=1,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
)

def _launch(self):
args = CLIArgs(self.task)
env = dict(os.environ)
@@ -1597,21 +1554,3 @@ def terminate(self):

def __str__(self):
return "Worker[%d]: %s" % (self._proc.pid, self.task.path)


class Test:
def helper(self, a, b):
print(f"{a}+{b}: ", a + b)
return a + b

def process(self):
inputs = []
for i in range(3):
inputs.append((i, i * 2))
from concurrent import futures

with futures.ThreadPoolExecutor(max_workers=64) as executor:
all_tasks = [executor.submit(self.helper, x, y) for (x, y) in inputs]
res, _ = futures.wait(all_tasks)
results = [future.result() for future in res]
print(results)
