Skip to content

Commit

Permalink
Also tell the server about uploaded files on file heartbeats (#2312)
Browse files Browse the repository at this point in the history
  • Loading branch information
minyoung authored Jun 29, 2021
1 parent 762dc2b commit 72b49df
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 35 deletions.
6 changes: 6 additions & 0 deletions tests/test_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def test_save_live_existing_file(
internal_sender.publish_files({"files": [("test.txt", "live")]})
stop_backend()
assert len(mock_server.ctx["storage?file=test.txt"]) == 1
assert any(
[
"test.txt" in request_dict.get("uploaded", [])
for request_dict in mock_server.ctx["file_stream"]
]
)


def test_save_live_write_after_policy(
Expand Down
6 changes: 3 additions & 3 deletions wandb/filesync/step_checksum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
RequestCommitArtifact = collections.namedtuple(
"RequestCommitArtifact", ("artifact_id", "finalize", "before_commit", "on_commit")
)
RequestFinish = collections.namedtuple("RequestFinish", ())
# Request telling the step to finish. ``callback`` is passed along as
# step_upload.RequestFinish(req.callback) on the output queue; finish() in
# this module passes None.
# NOTE: the field list is a real 1-tuple. ("callback") without the trailing
# comma is just the string "callback", which namedtuple happens to accept.
RequestFinish = collections.namedtuple("RequestFinish", ("callback",))


class StepChecksum(object):
Expand Down Expand Up @@ -113,7 +113,7 @@ def make_save_fn_with_entry(save_fn, entry):
else:
raise Exception("internal error")

self._output_queue.put(step_upload.RequestFinish())
self._output_queue.put(step_upload.RequestFinish(req.callback))

def start(self):
"""Start the underlying thread (``self._thread``)."""
self._thread.start()
Expand All @@ -122,4 +122,4 @@ def is_alive(self):
return self._thread.is_alive()

def finish(self):
self._request_queue.put(RequestFinish())
self._request_queue.put(RequestFinish(None))
10 changes: 8 additions & 2 deletions wandb/filesync/step_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
RequestCommitArtifact = collections.namedtuple(
"RequestCommitArtifact", ("artifact_id", "finalize", "before_commit", "on_commit")
)
RequestFinish = collections.namedtuple("RequestFinish", ())
# Finish request for the upload thread. ``callback`` may be None; when it is
# set, _thread_body invokes it once the event queue is empty and no upload
# jobs are left running.
# NOTE: the field list is a real 1-tuple. ("callback") without the trailing
# comma is just the string "callback", which namedtuple happens to accept.
RequestFinish = collections.namedtuple("RequestFinish", ("callback",))


class StepUpload(object):
def __init__(self, api, stats, event_queue, max_jobs, silent=False):
def __init__(self, api, stats, event_queue, max_jobs, file_stream, silent=False):
self._api = api
self._stats = stats
self._event_queue = event_queue
self._max_jobs = max_jobs
self._file_stream = file_stream

self._thread = threading.Thread(target=self._thread_body)
self._thread.daemon = True
Expand All @@ -40,9 +41,11 @@ def __init__(self, api, stats, event_queue, max_jobs, silent=False):
def _thread_body(self):
# Wait for event in the queue, and process one by one until a
# finish event is received
finish_callback = None
while True:
event = self._event_queue.get()
if isinstance(event, RequestFinish):
finish_callback = event.callback
break
self._handle_event(event)

Expand All @@ -62,6 +65,8 @@ def _thread_body(self):
self._handle_event(event)
elif not self._running_jobs:
# Queue was empty and no jobs left.
if finish_callback:
finish_callback()
break

def _handle_event(self, event):
Expand Down Expand Up @@ -123,6 +128,7 @@ def _start_upload_job(self, event):
self._event_queue,
self._stats,
self._api,
self._file_stream,
self.silent,
event.save_name,
event.path,
Expand Down
4 changes: 4 additions & 0 deletions wandb/filesync/upload_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
done_queue,
stats,
api,
file_stream,
silent,
save_name,
path,
Expand All @@ -38,6 +39,7 @@ def __init__(
self._done_queue = done_queue
self._stats = stats
self._api = api
self._file_stream = file_stream
self.silent = silent
self.save_name = save_name
self.save_path = self.path = path
Expand All @@ -56,6 +58,8 @@ def run(self):
if self.copied and os.path.isfile(self.save_path):
os.remove(self.save_path)
self._done_queue.put(EventJobDone(self, success))
if success:
self._file_stream.push_success(self.artifact_id, self.save_name)

def push(self):
try:
Expand Down
7 changes: 4 additions & 3 deletions wandb/sdk/internal/file_pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class FilePusher(object):

MAX_UPLOAD_JOBS = 64

def __init__(self, api, silent=False):
def __init__(self, api, file_stream, silent=False):
self._api = api

self._tempdir = tempfile.TemporaryDirectory("wandb")
Expand All @@ -74,6 +74,7 @@ def __init__(self, api, silent=False):
self._stats,
self._event_queue,
self.MAX_UPLOAD_JOBS,
file_stream=file_stream,
silent=silent,
)
self._step_upload.start()
Expand Down Expand Up @@ -172,9 +173,9 @@ def commit_artifact(
)
self._incoming_queue.put(event)

def finish(self):
def finish(self, callback=None):
logger.info("shutting down file pusher")
self._incoming_queue.put(step_checksum.RequestFinish())
self._incoming_queue.put(step_checksum.RequestFinish(callback))

def join(self):
# NOTE: must have called finish before join
Expand Down
33 changes: 30 additions & 3 deletions wandb/sdk/internal/file_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class FileStreamApi(object):

# Control items placed on the queue alongside Chunk items.
# Finish: carries the exit code that _thread_body sends as int(exitcode) in
# the final "complete": True post. The field list is written as a real
# 1-tuple; ("exitcode") is only a parenthesised string.
Finish = collections.namedtuple("Finish", ("exitcode",))
# Preempting: no payload (presumably what triggers the "preempting": True
# post in _thread_body; confirm against the full file).
Preempting = collections.namedtuple("Preempting", ())
# PushSuccess: a completed file upload; _thread_body adds save_name to the
# "uploaded" set that is reported with the next post.
PushSuccess = collections.namedtuple("PushSuccess", ("artifact_id", "save_name"))

HTTP_TIMEOUT = env.get_http_timeout(10)
MAX_ITEMS_PER_PUSH = 10000
Expand Down Expand Up @@ -241,6 +242,7 @@ def _thread_body(self):
posted_data_time = time.time()
posted_anything_time = time.time()
ready_chunks = []
uploaded = set()
finished = None
while finished is None:
items = self._read_queue()
Expand All @@ -251,8 +253,15 @@ def _thread_body(self):
request_with_retry(
self._client.post,
self._endpoint,
json={"complete": False, "preempting": True},
json={
"complete": False,
"preempting": True,
"uploaded": list(uploaded),
},
)
uploaded = set()
elif isinstance(item, self.PushSuccess):
uploaded.add(item.save_name)
else:
# item is Chunk
ready_chunks.append(item)
Expand All @@ -273,14 +282,23 @@ def _thread_body(self):
request_with_retry(
self._client.post,
self._endpoint,
json={"complete": False, "failed": False},
json={
"complete": False,
"failed": False,
"uploaded": list(uploaded),
},
)
)
uploaded = set()
# post the final close message. (item is self.Finish instance now)
request_with_retry(
self._client.post,
self._endpoint,
json={"complete": True, "exitcode": int(finished.exitcode)},
json={
"complete": True,
"exitcode": int(finished.exitcode),
"uploaded": list(uploaded),
},
)

def _thread_except_body(self):
Expand Down Expand Up @@ -353,6 +371,15 @@ def push(self, filename, data):
"""
self._queue.put(Chunk(filename, data))

def push_success(self, artifact_id, save_name):
"""Record that a file upload has completed successfully.

Enqueues a PushSuccess(artifact_id, save_name) item on self._queue.

Arguments:
artifact_id: ID of artifact
save_name: saved name of the uploaded file
"""
self._queue.put(self.PushSuccess(artifact_id, save_name))

def finish(self, exitcode):
"""Cleans up.
Expand Down
30 changes: 21 additions & 9 deletions wandb/sdk/internal/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,44 +303,56 @@ def send_request_defer(self, data):
state = defer.state
logger.info("handle sender defer: {}".format(state))

def transition_state():
# Advance the defer sequence by one step: log and publish a defer
# request for the state after defer.state. ``state`` here is a new local
# and does not rebind the enclosing function's ``state``.
state = defer.state + 1
logger.info("send defer: {}".format(state))
self._interface.publish_defer(state)

done = False
if state == defer.BEGIN:
pass
transition_state()
elif state == defer.FLUSH_STATS:
# NOTE: this is handled in handler.py:handle_request_defer()
pass
transition_state()
elif state == defer.FLUSH_TB:
# NOTE: this is handled in handler.py:handle_request_defer()
pass
transition_state()
elif state == defer.FLUSH_SUM:
# NOTE: this is handled in handler.py:handle_request_defer()
pass
transition_state()
elif state == defer.FLUSH_DEBOUNCER:
self.debounce()
transition_state()
elif state == defer.FLUSH_DIR:
if self._dir_watcher:
self._dir_watcher.finish()
self._dir_watcher = None
transition_state()
elif state == defer.FLUSH_FP:
if self._pusher:
self._pusher.finish()
# FilePusher generates some events for FileStreamApi, so we
# need to wait for pusher to finish before going to the next
# state to ensure that filestream gets all the events that we
# want before telling it to finish up
self._pusher.finish(transition_state)
else:
transition_state()
elif state == defer.FLUSH_FS:
if self._fs:
# TODO(jhr): now is a good time to output pending output lines
self._fs.finish(self._exit_code)
self._fs = None
transition_state()
elif state == defer.FLUSH_FINAL:
self._interface.publish_final()
self._interface.publish_footer()
transition_state()
elif state == defer.END:
done = True
else:
raise AssertionError("unknown state")

if not done:
state += 1
logger.info("send defer: {}".format(state))
self._interface.publish_defer(state)
return

exit_result = wandb_internal_pb2.RunExitResult()
Expand Down Expand Up @@ -697,7 +709,7 @@ def _start_run_threads(self, file_dir=None):
email=self._settings.email,
)
self._fs.start()
self._pusher = FilePusher(self._api, silent=self._settings.silent)
self._pusher = FilePusher(self._api, self._fs, silent=self._settings.silent)
self._dir_watcher = DirWatcher(
self._settings, self._api, self._pusher, file_dir
)
Expand Down
7 changes: 4 additions & 3 deletions wandb/sdk_py27/internal/file_pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class FilePusher(object):

MAX_UPLOAD_JOBS = 64

def __init__(self, api, silent=False):
def __init__(self, api, file_stream, silent=False):
self._api = api

self._tempdir = tempfile.TemporaryDirectory("wandb")
Expand All @@ -74,6 +74,7 @@ def __init__(self, api, silent=False):
self._stats,
self._event_queue,
self.MAX_UPLOAD_JOBS,
file_stream=file_stream,
silent=silent,
)
self._step_upload.start()
Expand Down Expand Up @@ -172,9 +173,9 @@ def commit_artifact(
)
self._incoming_queue.put(event)

def finish(self):
def finish(self, callback=None):
logger.info("shutting down file pusher")
self._incoming_queue.put(step_checksum.RequestFinish())
self._incoming_queue.put(step_checksum.RequestFinish(callback))

def join(self):
# NOTE: must have called finish before join
Expand Down
33 changes: 30 additions & 3 deletions wandb/sdk_py27/internal/file_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class FileStreamApi(object):

# Control items placed on the queue alongside Chunk items.
# Finish: carries the exit code that _thread_body sends as int(exitcode) in
# the final "complete": True post. The field list is written as a real
# 1-tuple; ("exitcode") is only a parenthesised string.
Finish = collections.namedtuple("Finish", ("exitcode",))
# Preempting: no payload (presumably what triggers the "preempting": True
# post in _thread_body; confirm against the full file).
Preempting = collections.namedtuple("Preempting", ())
# PushSuccess: a completed file upload; _thread_body adds save_name to the
# "uploaded" set that is reported with the next post.
PushSuccess = collections.namedtuple("PushSuccess", ("artifact_id", "save_name"))

HTTP_TIMEOUT = env.get_http_timeout(10)
MAX_ITEMS_PER_PUSH = 10000
Expand Down Expand Up @@ -241,6 +242,7 @@ def _thread_body(self):
posted_data_time = time.time()
posted_anything_time = time.time()
ready_chunks = []
uploaded = set()
finished = None
while finished is None:
items = self._read_queue()
Expand All @@ -251,8 +253,15 @@ def _thread_body(self):
request_with_retry(
self._client.post,
self._endpoint,
json={"complete": False, "preempting": True},
json={
"complete": False,
"preempting": True,
"uploaded": list(uploaded),
},
)
uploaded = set()
elif isinstance(item, self.PushSuccess):
uploaded.add(item.save_name)
else:
# item is Chunk
ready_chunks.append(item)
Expand All @@ -273,14 +282,23 @@ def _thread_body(self):
request_with_retry(
self._client.post,
self._endpoint,
json={"complete": False, "failed": False},
json={
"complete": False,
"failed": False,
"uploaded": list(uploaded),
},
)
)
uploaded = set()
# post the final close message. (item is self.Finish instance now)
request_with_retry(
self._client.post,
self._endpoint,
json={"complete": True, "exitcode": int(finished.exitcode)},
json={
"complete": True,
"exitcode": int(finished.exitcode),
"uploaded": list(uploaded),
},
)

def _thread_except_body(self):
Expand Down Expand Up @@ -353,6 +371,15 @@ def push(self, filename, data):
"""
self._queue.put(Chunk(filename, data))

def push_success(self, artifact_id, save_name):
"""Record that a file upload has completed successfully.

Enqueues a PushSuccess(artifact_id, save_name) item on self._queue.

Arguments:
artifact_id: ID of artifact
save_name: saved name of the uploaded file
"""
self._queue.put(self.PushSuccess(artifact_id, save_name))

def finish(self, exitcode):
"""Cleans up.
Expand Down
Loading

0 comments on commit 72b49df

Please sign in to comment.