From d1c14180e702e02312a4dbce7ebd54b17a617808 Mon Sep 17 00:00:00 2001 From: rcelyte Date: Thu, 21 Nov 2024 17:57:52 +0000 Subject: [PATCH] SolarXR IPC --- .../src/main/java/dev/slimevr/VRServer.kt | 26 ++-- .../dev/slimevr/protocol/ProtocolAPI.java | 4 + .../src/main/java/dev/slimevr/desktop/Main.kt | 114 ++++++++--------- .../platform/linux/UnixSocketConnection.java | 111 +++++++++++++++++ .../platform/linux/UnixSocketRpcBridge.java | 115 ++++++++++++++++++ 5 files changed, 290 insertions(+), 80 deletions(-) create mode 100644 server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketConnection.java create mode 100644 server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketRpcBridge.java diff --git a/server/core/src/main/java/dev/slimevr/VRServer.kt b/server/core/src/main/java/dev/slimevr/VRServer.kt index 6ae919fe8e..dc141905c2 100644 --- a/server/core/src/main/java/dev/slimevr/VRServer.kt +++ b/server/core/src/main/java/dev/slimevr/VRServer.kt @@ -36,16 +36,15 @@ import java.util.concurrent.atomic.AtomicInteger import java.util.function.Consumer import kotlin.concurrent.schedule -typealias SteamBridgeProvider = ( +typealias BridgeProvider = ( server: VRServer, computedTrackers: List, -) -> ISteamVRBridge? +) -> Sequence const val SLIMEVR_IDENTIFIER = "dev.slimevr.SlimeVR" class VRServer @JvmOverloads constructor( - driverBridgeProvider: SteamBridgeProvider = { _, _ -> null }, - feederBridgeProvider: (VRServer) -> ISteamVRBridge? = { _ -> null }, + bridgeProvider: BridgeProvider = { _, _ -> sequence {} }, serialHandlerProvider: (VRServer) -> SerialHandler = { _ -> SerialHandlerStub() }, acquireMulticastLock: () -> Any? = { null }, configPath: String, @@ -123,22 +122,11 @@ class VRServer @JvmOverloads constructor( "Sensors UDP server", ) { tracker: Tracker -> registerTracker(tracker) } - // Start bridges for SteamVR and Feeder - val driverBridge = driverBridgeProvider(this, computedTrackers) - if (driverBridge != null) { - tasks.add(Runnable { driverBridge.startBridge() }) - bridges.add(driverBridge) + // Start bridges and WebSocket server + for (bridge in bridgeProvider(this, computedTrackers) + sequenceOf(WebSocketVRBridge(computedTrackers, this))) { + tasks.add(Runnable { bridge.startBridge() }) + bridges.add(bridge) } - val feederBridge = feederBridgeProvider(this) - if (feederBridge != null) { - tasks.add(Runnable { feederBridge.startBridge() }) - bridges.add(feederBridge) - } - - // Create WebSocket server - val wsBridge = WebSocketVRBridge(computedTrackers, this) - tasks.add(Runnable { wsBridge.startBridge() }) - bridges.add(wsBridge) // Initialize OSC handlers vrcOSCHandler = VRCOSCHandler( diff --git a/server/core/src/main/java/dev/slimevr/protocol/ProtocolAPI.java b/server/core/src/main/java/dev/slimevr/protocol/ProtocolAPI.java index 7bb090e643..40e6837b79 100644 --- a/server/core/src/main/java/dev/slimevr/protocol/ProtocolAPI.java +++ b/server/core/src/main/java/dev/slimevr/protocol/ProtocolAPI.java @@ -11,6 +11,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; @@ -31,6 +32,9 @@ public ProtocolAPI(VRServer server) { } public void onMessage(GenericConnection conn, ByteBuffer message) { + if(message.position() != 0) + message = ByteBuffer.wrap(Arrays.copyOfRange(message.array(), message.position(), message.limit())); + MessageBundle messageBundle = MessageBundle.getRootAsMessageBundle(message); try { diff --git a/server/desktop/src/main/java/dev/slimevr/desktop/Main.kt b/server/desktop/src/main/java/dev/slimevr/desktop/Main.kt index cc4662d70c..38dfc10842 100644 --- a/server/desktop/src/main/java/dev/slimevr/desktop/Main.kt +++ b/server/desktop/src/main/java/dev/slimevr/desktop/Main.kt @@ -5,9 +5,11 @@ package dev.slimevr.desktop import dev.slimevr.Keybinding import dev.slimevr.SLIMEVR_IDENTIFIER import dev.slimevr.VRServer +import dev.slimevr.bridge.Bridge import dev.slimevr.bridge.ISteamVRBridge import dev.slimevr.desktop.platform.SteamVRBridge import dev.slimevr.desktop.platform.linux.UnixSocketBridge +import dev.slimevr.desktop.platform.linux.UnixSocketRpcBridge import dev.slimevr.desktop.platform.windows.WindowsNamedPipeBridge import dev.slimevr.desktop.serial.DesktopSerialHandler import dev.slimevr.desktop.tracking.trackers.hid.TrackersHID @@ -118,8 +120,7 @@ fun main(args: Array) { val configDir = resolveConfig() LogManager.info("Using config dir: $configDir") val vrServer = VRServer( - ::provideSteamVRBridge, - ::provideFeederBridge, + ::provideBridges, { _ -> DesktopSerialHandler() }, configPath = configDir, ) @@ -149,90 +150,81 @@ fun main(args: Array) { } } -fun provideSteamVRBridge( +fun provideBridges( server: VRServer, computedTrackers: List, -): ISteamVRBridge? { - val driverBridge: SteamVRBridge? - if (OperatingSystem.currentPlatform == OperatingSystem.WINDOWS) { - // Create named pipe bridge for SteamVR driver - driverBridge = WindowsNamedPipeBridge( - server, - "steamvr", - "SteamVR Driver Bridge", - """\\.\pipe\SlimeVRDriver""", - computedTrackers, - ) - } else if (OperatingSystem.currentPlatform == OperatingSystem.LINUX) { - var linuxBridge: SteamVRBridge? = null - try { - linuxBridge = UnixSocketBridge( +): Sequence = sequence { + when (OperatingSystem.currentPlatform) { + OperatingSystem.WINDOWS -> { + // Create named pipe bridge for SteamVR driver + yield(WindowsNamedPipeBridge( server, "steamvr", "SteamVR Driver Bridge", - Paths.get(OperatingSystem.socketDirectory, "SlimeVRDriver") - .toString(), + """\\.\pipe\SlimeVRDriver""", computedTrackers, - ) - } catch (ex: Exception) { - LogManager.severe( - "Failed to initiate Unix socket, disabling driver bridge...", - ex, - ) - } - driverBridge = linuxBridge - if (driverBridge != null) { - // Close the named socket on shutdown, or otherwise it's not going to get removed - Runtime.getRuntime().addShutdownHook( - Thread { - try { - (driverBridge as? UnixSocketBridge)?.close() - } catch (e: Exception) { - throw RuntimeException(e) - } - }, - ) - } - } else { - driverBridge = null - } + )) - return driverBridge -} - -fun provideFeederBridge( - server: VRServer, -): ISteamVRBridge? { - val feederBridge: SteamVRBridge? - when (OperatingSystem.currentPlatform) { - OperatingSystem.WINDOWS -> { // Create named pipe bridge for SteamVR input - feederBridge = WindowsNamedPipeBridge( + yield(WindowsNamedPipeBridge( server, "steamvr_feeder", "SteamVR Feeder Bridge", """\\.\pipe\SlimeVRInput""", FastList(), - ) + )) } OperatingSystem.LINUX -> { - feederBridge = UnixSocketBridge( + var linuxBridge: SteamVRBridge? = null + try { + linuxBridge = UnixSocketBridge( + server, + "steamvr", + "SteamVR Driver Bridge", + Paths.get(OperatingSystem.socketDirectory, "SlimeVRDriver") + .toString(), + computedTrackers, + ) + } catch (ex: Exception) { + LogManager.severe( + "Failed to initiate Unix socket, disabling driver bridge...", + ex, + ) + } + if (linuxBridge != null) { + // Close the named socket on shutdown, or otherwise it's not going to get removed + Runtime.getRuntime().addShutdownHook( + Thread { + try { + (linuxBridge as? UnixSocketBridge)?.close() + } catch (e: Exception) { + throw RuntimeException(e) + } + }, + ) + yield(linuxBridge); + } + + yield(UnixSocketBridge( server, "steamvr_feeder", "SteamVR Feeder Bridge", Paths.get(OperatingSystem.socketDirectory, "SlimeVRInput") .toString(), FastList(), - ) - } + )) - else -> { - feederBridge = null + yield(UnixSocketRpcBridge( + server, + Paths.get(OperatingSystem.socketDirectory, "SlimeVRRpc") + .toString(), + computedTrackers, + )) } - } - return feederBridge + else -> {} + } } const val CONFIG_FILENAME = "vrconfig.yml" diff --git a/server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketConnection.java b/server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketConnection.java new file mode 100644 index 0000000000..f1b2d1415c --- /dev/null +++ b/server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketConnection.java @@ -0,0 +1,111 @@ +package dev.slimevr.desktop.platform.linux; + +import dev.slimevr.protocol.ConnectionContext; +import dev.slimevr.protocol.GenericConnection; +import io.eiren.util.logging.LogManager; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.SocketChannel; +import java.util.Arrays; +import java.util.UUID; + +public class UnixSocketConnection implements GenericConnection { + public final UUID id; + public final ConnectionContext context; + private final ByteBuffer dst = ByteBuffer.allocate(2048).order(ByteOrder.LITTLE_ENDIAN); + private final SocketChannel channel; + + public UnixSocketConnection(SocketChannel channel) { + this.id = UUID.randomUUID(); + this.context = new ConnectionContext(); + this.channel = channel; + } + + @Override + public UUID getConnectionId() { + return id; + } + + @Override + public ConnectionContext getContext() { + return this.context; + } + + public boolean isConnected() { + return this.channel.isConnected(); + } + + private void resetChannel() { + try { + this.channel.close(); + } catch(IOException e) { + e.printStackTrace(); + } + } + + @Override + public void send(ByteBuffer bytes) { + if (!this.channel.isConnected()) + return; + try { + ByteBuffer[] src = new ByteBuffer[] { + ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN), + bytes.slice(), + }; + src[0].putInt(src[1].remaining() + 4); + src[0].flip(); + synchronized(this) { + while (src[1].hasRemaining()) { + this.channel.write(src); + } + } + } catch (IOException e) { + e.printStackTrace(); + } + } + + public ByteBuffer read() { + if(dst.position() < 4) { + if (!this.channel.isConnected()) + return null; + try { + int result = this.channel.read(dst); + if (result == -1) { + LogManager.info("[SolarXR Bridge] Reached end-of-stream on connection"); + this.resetChannel(); + return null; + } + if (result == 0 || dst.position() < 4) { + return null; + } + } catch(IOException e) { + e.printStackTrace(); + this.resetChannel(); + return null; + } + } + int messageLength = dst.getInt(0); + if (messageLength > 1024) { + LogManager.severe("[SolarXR Bridge] Buffer overflow on socket. Message length: " + messageLength); + this.resetChannel(); + return null; + } + if (dst.position() < messageLength) { + return null; + } + ByteBuffer message = dst.slice(); + message.position(4); + message.limit(messageLength); + return message; + } + + public void next() { + int messageLength = dst.getInt(0); + int originalpos = dst.position(); + dst.position(messageLength); + dst.compact(); + dst.position(originalpos - messageLength); + } +} diff --git a/server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketRpcBridge.java b/server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketRpcBridge.java new file mode 100644 index 0000000000..ee6626f871 --- /dev/null +++ b/server/desktop/src/main/java/dev/slimevr/desktop/platform/linux/UnixSocketRpcBridge.java @@ -0,0 +1,115 @@ +package dev.slimevr.desktop.platform.linux; + +import dev.slimevr.bridge.BridgeThread; +import dev.slimevr.protocol.GenericConnection; +import dev.slimevr.protocol.ProtocolAPI; +import dev.slimevr.tracking.trackers.Tracker; +import dev.slimevr.util.ann.VRServerThread; +import dev.slimevr.VRServer; +import io.eiren.util.logging.LogManager; + +import java.io.File; +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.List; + +public class UnixSocketRpcBridge implements dev.slimevr.bridge.Bridge, dev.slimevr.protocol.ProtocolAPIServer, Runnable, AutoCloseable { + private final Thread runnerThread = new Thread(this, "Named socket thread"); + private final String socketPath; + private final ProtocolAPI protocolAPI; + private final ServerSocketChannel socket; + private final Selector selector; + + public UnixSocketRpcBridge(VRServer server, String socketPath, List shareableTrackers) { + this.socketPath = socketPath; + this.protocolAPI = server.protocolAPI; + File socketFile = new File(socketPath); + if(socketFile.exists()) + throw new RuntimeException(socketPath + " socket already exists."); + socketFile.deleteOnExit(); + try { + socket = ServerSocketChannel.open(StandardProtocolFamily.UNIX); + selector = Selector.open(); + } catch(IOException e) { + e.printStackTrace(); + throw new RuntimeException("Socket open failed."); + } + + server.protocolAPI.registerAPIServer(this); + } + + @VRServerThread + private void disconnected() {} + + @Override + @VRServerThread + public void dataRead() {} + + @Override + @VRServerThread + public void dataWrite() {} + + @Override + @VRServerThread + public void addSharedTracker(Tracker tracker) {} + + @Override + @VRServerThread + public void removeSharedTracker(Tracker tracker) {} + + @Override + @VRServerThread + public void startBridge() { + this.runnerThread.start(); + } + + @Override + @BridgeThread + public void run() { + try { + this.socket.bind(UnixDomainSocketAddress.of(this.socketPath)); + this.socket.configureBlocking(false); + this.socket.register(this.selector, SelectionKey.OP_ACCEPT); + LogManager.info("[SolarXR Bridge] Socket " + this.socketPath + " created"); + while(this.socket.isOpen()) { + this.selector.select(0); + for(SelectionKey key : this.selector.selectedKeys()) { + UnixSocketConnection conn = (UnixSocketConnection)key.attachment(); + if(conn != null) { + for(ByteBuffer message; (message = conn.read()) != null; conn.next()) + this.protocolAPI.onMessage(conn, message); + } else for(SocketChannel channel; (channel = socket.accept()) != null;) { + channel.configureBlocking(false); + channel.register(this.selector, SelectionKey.OP_READ, new UnixSocketConnection(channel)); + LogManager.info("[SolarXR Bridge] Connected to " + channel.getRemoteAddress().toString()); + } + } + } + } catch(IOException e) { + e.printStackTrace(); + } + } + + @Override + public void close() throws Exception { + this.socket.close(); + this.selector.close(); + } + + @Override + public boolean isConnected() { + return this.selector.keys().stream().anyMatch(key -> key.attachment() != null); + } + + @Override + public java.util.stream.Stream getAPIConnections() { + return this.selector.keys().stream().map(key -> (GenericConnection)key.attachment()).filter(conn -> conn != null); + } +}