diff --git a/wayfire/ipc.py b/wayfire/ipc.py index fa90335..2cc474c 100644 --- a/wayfire/ipc.py +++ b/wayfire/ipc.py @@ -9,21 +9,22 @@ class WayfireSocket: def __init__(self, socket_name: str | None=None, allow_manual_search=False): if socket_name is None: - socket_name = os.getenv("WAYFIRE_SOCKET") + env_socket = os.getenv("WAYFIRE_SOCKET") + socket_name = env_socket.strip() if env_socket else None self.socket_name = None self.pending_events = [] self.timeout = 3 - if socket_name is None and allow_manual_search: - # the last item is the most recent socket file - socket_list = sorted( - [ - os.path.join("/tmp", i) - for i in os.listdir("/tmp") - if "wayfire-wayland" in i - ] - ) + if socket_name is not None: + try: + self.connect_client(socket_name) + self.socket_name = socket_name + except Exception: + socket_name = None + + if self.socket_name is None and allow_manual_search: + socket_list = self._find_candidate_sockets() for candidate in socket_list: try: @@ -33,12 +34,38 @@ def __init__(self, socket_name: str | None=None, allow_manual_search=False): except Exception: pass - elif socket_name is not None: - self.connect_client(socket_name) - self.socket_name = socket_name - if self.socket_name is None: - raise Exception("Failed to find a suitable Wayfire socket!") + if not allow_manual_search: + raise Exception( + "Failed to find a suitable Wayfire socket! " + "Try allow_manual_search=True to look in standard locations." + ) + else: + raise Exception( + "Failed to find a suitable Wayfire socket! " + "Manual search was performed but found no working socket. " + ) + + def _find_candidate_sockets(self) -> List[str]: + socket_list = [] + + runtime_dir = os.getenv("XDG_RUNTIME_DIR") + if runtime_dir is not None and os.path.isdir(runtime_dir): + for item in os.listdir(runtime_dir): + if item.startswith("wayfire-wayland-") and item.endswith(".socket"): + socket_list.append(os.path.join(runtime_dir, item)) + socket_list.sort() + + tmp_sockets = sorted( + [ + os.path.join("/tmp", i) + for i in os.listdir("/tmp") + if i.startswith("wayfire-wayland-") and i.endswith(".socket") + ] + ) + socket_list.extend(tmp_sockets) + + return socket_list def connect_client(self, socket_name): self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)