Skip to content

Commit

Permalink
Introduce bolt+unix scheme support (#1586)
Browse files Browse the repository at this point in the history
The new `bolt+unix` scheme allows connecting to Neo4j server over a Unix socket.

Example:
```java
try (var driver = GraphDatabase.driver("bolt+unix:///var/run/neo4j.sock")) {
    // use the driver
    var result = driver.executableQuery("SHOW DATABASES")
            .withConfig(QueryConfig.builder().withDatabase("system").build())
            .execute();
    result.records().forEach(System.out::println);
}
```
  • Loading branch information
injectives authored Nov 21, 2024
1 parent c85bbf5 commit b4f74cc
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 155 deletions.
25 changes: 21 additions & 4 deletions driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import io.netty.channel.local.LocalAddress;
import io.netty.util.concurrent.EventExecutorGroup;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Clock;
import java.util.LinkedHashSet;
import java.util.Set;
Expand Down Expand Up @@ -111,7 +113,6 @@ public final Driver newInstance(
ownsEventLoopGroup = false;
}

var address = new InternalServerAddress(uri);
var routingSettings = new RoutingSettings(config.routingTablePurgeDelayMillis(), new RoutingContext(uri));

EventExecutorGroup eventExecutorGroup = bootstrap.config().group();
Expand All @@ -122,7 +123,6 @@ public final Driver newInstance(
return createDriver(
uri,
securityPlanManager,
address,
bootstrap.group(),
routingSettings,
retryLogic,
Expand All @@ -149,7 +149,6 @@ protected static MetricsProvider getOrCreateMetricsProvider(Config config, Clock
private InternalDriver createDriver(
URI uri,
BoltSecurityPlanManager securityPlanManager,
ServerAddress address,
EventLoopGroup eventLoopGroup,
RoutingSettings routingSettings,
RetryLogic retryLogic,
Expand All @@ -159,11 +158,29 @@ private InternalDriver createDriver(
boolean ownsEventLoopGroup,
Supplier<Rediscovery> rediscoverySupplier) {
BoltConnectionProvider boltConnectionProvider = null;
BoltServerAddress address;
if (Scheme.BOLT_UNIX_URI_SCHEME.equals(uri.getScheme())) {
var path = Path.of(uri.getPath());
if (!Files.exists(path)) {
throw new IllegalArgumentException(String.format("%s does not exist", path));
}
address = new BoltServerAddress(path);
} else {
var port = uri.getPort();
if (port == -1) {
port = InternalServerAddress.DEFAULT_PORT;
}
if (port >= 0 && port <= 65_535) {
address = new BoltServerAddress(uri.getHost(), port);
} else {
throw new IllegalArgumentException("Illegal port: " + port);
}
}
try {
boltConnectionProvider =
createBoltConnectionProvider(uri, config, eventLoopGroup, routingSettings, rediscoverySupplier);
boltConnectionProvider.init(
new BoltServerAddress(address.host(), address.port()),
address,
new RoutingContext(uri),
DriverInfoUtil.boltAgent(),
config.userAgent(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ private static void requireValidPort(int port) {
throw new IllegalArgumentException("Illegal port: " + port);
}

public InternalServerAddress(String address) {
this(uriFrom(address));
}

public InternalServerAddress(URI uri) {
this(hostFrom(uri), portFrom(uri));
}
Expand All @@ -64,43 +60,6 @@ private static RuntimeException invalidAddressFormat(String address) {
return new IllegalArgumentException("Invalid address format `" + address + "`");
}

@SuppressWarnings("DuplicatedCode")
private static URI uriFrom(String address) {
String scheme;
String hostPort;

var schemeSplit = address.split("://");
if (schemeSplit.length == 1) {
// URI can't parse addresses without scheme, prepend fake "bolt://" to reuse the parsing facility
scheme = "bolt://";
hostPort = hostPortFrom(schemeSplit[0]);
} else if (schemeSplit.length == 2) {
scheme = schemeSplit[0] + "://";
hostPort = hostPortFrom(schemeSplit[1]);
} else {
throw invalidAddressFormat(address);
}

return URI.create(scheme + hostPort);
}

private static String hostPortFrom(String address) {
if (address.startsWith("[")) {
// expected to be an IPv6 address like [::1] or [::1]:7687
return address;
}

var containsSingleColon = address.indexOf(":") == address.lastIndexOf(":");
if (containsSingleColon) {
// expected to be an IPv4 address with or without port like 127.0.0.1 or 127.0.0.1:7687
return address;
}

// address contains multiple colons and does not start with '['
// expected to be an IPv6 address without brackets
return "[" + address + "]";
}

@Override
public String toString() {
return String.format("%s:%d", host, port);
Expand Down
2 changes: 2 additions & 0 deletions driver/src/main/java/org/neo4j/driver/internal/Scheme.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class Scheme {
public static final String BOLT_URI_SCHEME = "bolt";
public static final String BOLT_HIGH_TRUST_URI_SCHEME = "bolt+s";
public static final String BOLT_LOW_TRUST_URI_SCHEME = "bolt+ssc";
public static final String BOLT_UNIX_URI_SCHEME = "bolt+unix";
public static final String NEO4J_URI_SCHEME = "neo4j";
public static final String NEO4J_HIGH_TRUST_URI_SCHEME = "neo4j+s";
public static final String NEO4J_LOW_TRUST_URI_SCHEME = "neo4j+ssc";
Expand All @@ -34,6 +35,7 @@ public static void validateScheme(String scheme) {
case BOLT_URI_SCHEME,
BOLT_LOW_TRUST_URI_SCHEME,
BOLT_HIGH_TRUST_URI_SCHEME,
BOLT_UNIX_URI_SCHEME,
NEO4J_URI_SCHEME,
NEO4J_LOW_TRUST_URI_SCHEME,
NEO4J_HIGH_TRUST_URI_SCHEME -> {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static java.util.Objects.requireNonNull;

import java.net.URI;
import java.nio.file.Path;
import java.util.Objects;
import java.util.stream.Stream;

Expand All @@ -35,6 +36,7 @@ public class BoltServerAddress {
// resolved IP address.
protected final int port;
private final String stringValue;
private final Path path;

public BoltServerAddress(String address) {
this(uriFrom(address));
Expand All @@ -55,6 +57,15 @@ public BoltServerAddress(String host, String connectionHost, int port) {
this.stringValue = host.equals(connectionHost)
? String.format("%s:%d", host, port)
: String.format("%s(%s):%d", host, connectionHost, port);
this.path = null;
}

public BoltServerAddress(Path path) {
this.host = path.toString();
this.connectionHost = this.host;
this.port = -1;
this.stringValue = this.host;
this.path = path;
}

@Override
Expand Down Expand Up @@ -91,6 +102,10 @@ public String connectionHost() {
return connectionHost;
}

public Path path() {
return path;
}

/**
* Create a stream of unicast addresses.
* <p>
Expand All @@ -115,7 +130,6 @@ private static int portFrom(URI uri) {
return port == -1 ? DEFAULT_PORT : port;
}

@SuppressWarnings("DuplicatedCode")
private static URI uriFrom(String address) {
String scheme;
String hostPort;
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public final class NettyBoltConnectionProvider implements BoltConnectionProvider
private final LoggingProvider logging;
private final System.Logger log;

private final ConnectionProvider connectionProvider;
private final NettyConnectionProvider connectionProvider;

private BoltServerAddress address;

Expand All @@ -76,7 +76,7 @@ public NettyBoltConnectionProvider(
this.logging = Objects.requireNonNull(logging);
this.log = logging.getLog(getClass());
this.connectionProvider =
ConnectionProviders.netty(eventLoopGroup, clock, domainNameResolver, localAddress, logging);
new NettyConnectionProvider(eventLoopGroup, clock, domainNameResolver, localAddress, logging);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.socket.nio.NioDomainSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.resolver.AddressResolverGroup;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnixDomainSocketAddress;
import java.time.Clock;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
Expand All @@ -52,7 +54,7 @@
import org.neo4j.driver.internal.bolt.basicimpl.messaging.BoltProtocol;
import org.neo4j.driver.internal.bolt.basicimpl.spi.Connection;

public final class NettyConnectionProvider implements ConnectionProvider {
public final class NettyConnectionProvider {
private final EventLoopGroup eventLoopGroup;
private final Clock clock;
private final DomainNameResolver domainNameResolver;
Expand All @@ -75,7 +77,6 @@ public NettyConnectionProvider(
this.logging = logging;
}

@Override
public CompletionStage<Connection> acquireConnection(
BoltServerAddress address,
SecurityPlan securityPlan,
Expand All @@ -90,27 +91,9 @@ public CompletionStage<Connection> acquireConnection(
CompletableFuture<Long> latestAuthMillisFuture,
NotificationConfig notificationConfig,
MetricsListener metricsListener) {
var bootstrap = new Bootstrap();
bootstrap
.group(this.eventLoopGroup)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis)
.channel(localAddress != null ? LocalChannel.class : NioSocketChannel.class)
.resolver(addressResolverGroup)
.handler(new NettyChannelInitializer(address, securityPlan, connectTimeoutMillis, clock, logging));

SocketAddress socketAddress;
if (localAddress == null) {
try {
socketAddress =
new InetSocketAddress(domainNameResolver.resolve(address.connectionHost())[0], address.port());
} catch (Throwable t) {
socketAddress = InetSocketAddress.createUnresolved(address.connectionHost(), address.port());
}
} else {
socketAddress = localAddress;
}

return installChannelConnectedListeners(address, bootstrap.connect(socketAddress), connectTimeoutMillis)
return installChannelConnectedListeners(
address, connect(address, securityPlan, connectTimeoutMillis), connectTimeoutMillis)
.thenCompose(channel -> BoltProtocol.forChannel(channel)
.initializeChannel(
channel,
Expand All @@ -124,6 +107,39 @@ public CompletionStage<Connection> acquireConnection(
.thenApply(channel -> new NetworkConnection(channel, logging));
}

private ChannelFuture connect(BoltServerAddress address, SecurityPlan securityPlan, int connectTimeoutMillis) {
Class<? extends Channel> channelClass;
SocketAddress socketAddress;

if (localAddress != null) {
channelClass = LocalChannel.class;
socketAddress = localAddress;
} else {
if (address.path() != null) {
channelClass = NioDomainSocketChannel.class;
socketAddress = UnixDomainSocketAddress.of(address.path());
} else {
channelClass = NioSocketChannel.class;
try {
socketAddress = new InetSocketAddress(
domainNameResolver.resolve(address.connectionHost())[0], address.port());
} catch (Throwable t) {
socketAddress = InetSocketAddress.createUnresolved(address.connectionHost(), address.port());
}
}
}

var bootstrap = new Bootstrap();
bootstrap
.group(this.eventLoopGroup)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis)
.channel(channelClass)
.resolver(addressResolverGroup)
.handler(new NettyChannelInitializer(address, securityPlan, connectTimeoutMillis, clock, logging));

return bootstrap.connect(socketAddress);
}

private CompletionStage<Channel> installChannelConnectedListeners(
BoltServerAddress address, ChannelFuture channelConnected, int connectTimeoutMillis) {
var pipeline = channelConnected.channel().pipeline();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

class DriverFactoryTest {
private static Stream<String> testUris() {
return Stream.of("bolt://localhost:7687", "neo4j://localhost:7687");
return Stream.of("bolt://localhost:7687", "bolt+unix://localhost:7687", "neo4j://localhost:7687");
}

@ParameterizedTest
Expand Down
Loading

0 comments on commit b4f74cc

Please sign in to comment.