diff --git a/metaflow/sidecar/sidecar.py b/metaflow/sidecar/sidecar.py index d5c511bf80b..8358b356904 100644 --- a/metaflow/sidecar/sidecar.py +++ b/metaflow/sidecar/sidecar.py @@ -13,14 +13,24 @@ def __init__(self, sidecar_type): if t is not None and t.get_worker() is not None: self._has_valid_worker = True self.sidecar_process = None + # Whether to send msg in a thread-safe fashion. + self._threadsafe_send_enabled = False def start(self): if not self.is_active and self._has_valid_worker: self.sidecar_process = SidecarSubProcess(self._sidecar_type) + def enable_threadsafe_send(self): + self._threadsafe_send_enabled = True + + def disable_threadsafe_send(self): + self._threadsafe_send_enabled = False + def send(self, msg): if self.is_active: - self.sidecar_process.send(msg) + self.sidecar_process.send( + msg, thread_safe_send=self._threadsafe_send_enabled + ) def terminate(self): if self.is_active: diff --git a/metaflow/sidecar/sidecar_subprocess.py b/metaflow/sidecar/sidecar_subprocess.py index e58b049f8d1..34d4cc30815 100644 --- a/metaflow/sidecar/sidecar_subprocess.py +++ b/metaflow/sidecar/sidecar_subprocess.py @@ -25,6 +25,10 @@ except: blockingError = OSError +import threading + +lock = threading.Lock() + class PipeUnavailableError(Exception): """raised when unable to write to pipe given allotted time""" @@ -113,16 +117,16 @@ def kill(self): except: pass - def send(self, msg, retries=3): + def send(self, msg, retries=3, thread_safe_send=False): if msg.msg_type == MessageTypes.MUST_SEND: # If this is a must-send message, we treat it a bit differently. A must-send # message has to be properly sent before any of the other best effort messages. self._cached_mustsend = msg.payload self._send_mustsend_remaining_tries = MUST_SEND_RETRY_TIMES - self._send_mustsend(retries) + self._send_mustsend(retries, thread_safe_send) else: # Ignore return code for send. - self._send_internal(msg, retries=retries) + self._send_internal(msg, retries=retries, thread_safe_send=thread_safe_send) def _start_subprocess(self, cmdline): for _ in range(3): @@ -145,7 +149,7 @@ def _start_subprocess(self, cmdline): self._logger("Unknown popen error: %s" % repr(e)) break - def _send_internal(self, msg, retries=3): + def _send_internal(self, msg, retries=3, thread_safe_send=False): if self._process is None: return False try: @@ -157,13 +161,13 @@ def _send_internal(self, msg, retries=3): # restart sidecar so use the PipeUnavailableError caught below raise PipeUnavailableError() elif self._send_mustsend_remaining_tries > 0: - self._send_mustsend() + self._send_mustsend(thread_safe_send=thread_safe_send) if self._send_mustsend_remaining_tries == 0: - self._emit_msg(msg) + self._emit_msg(msg, thread_safe_send) self._prev_message_error = False return True else: - self._emit_msg(msg) + self._emit_msg(msg, thread_safe_send) self._prev_message_error = False return True return False @@ -184,14 +188,14 @@ def _send_internal(self, msg, retries=3): self._prev_message_error = True if retries > 0: self._logger("Retrying msg send to sidecar (due to %s)" % repr(ex)) - return self._send_internal(msg, retries - 1) + return self._send_internal(msg, retries - 1, thread_safe_send) else: self._logger( "Error sending log message (exhausted retries): %s" % repr(ex) ) return False - def _send_mustsend(self, retries=3): + def _send_mustsend(self, retries=3, thread_safe_send=False): if ( self._cached_mustsend is not None and self._send_mustsend_remaining_tries > 0 @@ -199,7 +203,9 @@ def _send_mustsend(self, retries=3): # If we don't succeed in sending the must-send, we will try again # next time. if self._send_internal( - Message(MessageTypes.MUST_SEND, self._cached_mustsend), retries + Message(MessageTypes.MUST_SEND, self._cached_mustsend), + retries, + thread_safe_send, ): self._cached_mustsend = None self._send_mustsend_remaining_tries = 0 @@ -211,14 +217,7 @@ def _send_mustsend(self, retries=3): self._send_mustsend_remaining_tries = -1 return False - def _emit_msg(self, msg): - # If the previous message had an error, we want to prepend a "\n" to this message - # to maximize the chance of this message being valid (for example, if the - # previous message only partially sent for whatever reason, we want to "clear" it) - msg = msg.serialize() - if self._prev_message_error: - msg = "\n" + msg - msg_ser = msg.encode("utf-8") + def _write_bytes(self, msg_ser): written_bytes = 0 while written_bytes < len(msg_ser): # self._logger("Sent %d out of %d bytes" % (written_bytes, len(msg_ser))) @@ -235,6 +234,23 @@ def _emit_msg(self, msg): # sidecar is disabled, ignore all messages break + def _emit_msg(self, msg, thread_safe_send=False): + # If the previous message had an error, we want to prepend a "\n" to this message + # to maximize the chance of this message being valid (for example, if the + # previous message only partially sent for whatever reason, we want to "clear" it) + msg = msg.serialize() + if self._prev_message_error: + msg = "\n" + msg + msg_ser = msg.encode("utf-8") + + # If threadsafe send is enabled, we will use a lock to ensure that only one thread + # can send a message at a time. This is to avoid interleaving of messages. + if thread_safe_send: + with lock: + self._write_bytes(msg_ser) + else: + self._write_bytes(msg_ser) + def _logger(self, msg): if debug.sidecar: print("[sidecar:%s] %s" % (self._worker_type, msg), file=sys.stderr)