Skip to content

Commit

Permalink
[CONJ-685] permit setting sslMode per host
Browse files Browse the repository at this point in the history
  • Loading branch information
rusher committed Jun 18, 2024
1 parent eca13e7 commit e8349a8
Show file tree
Hide file tree
Showing 16 changed files with 194 additions and 31 deletions.
27 changes: 27 additions & 0 deletions src/main/java/org/mariadb/jdbc/Configuration.java
Original file line number Diff line number Diff line change
Expand Up @@ -2335,6 +2335,19 @@ public Builder addHost(String host, int port) {
return this;
}

/**
* Add Host to possible addresses to connect
*
* @param host hostname or IP
* @param port port
* @param sslMode ssl mode. possible values disable/trust/verify-ca/verify-full
* @return this {@link Builder}
*/
public Builder addHost(String host, int port, String sslMode) {
this._addresses.add(HostAddress.from(nullOrEmpty(host), port, sslMode));
return this;
}

/**
* Add Host to possible addresses to connect
*
Expand All @@ -2348,6 +2361,20 @@ public Builder addHost(String host, int port, boolean master) {
return this;
}

/**
* Add Host to possible addresses to connect
*
* @param host hostname or IP
* @param port port
* @param master is master or replica
* @param sslMode ssl mode. possible values disable/trust/verify-ca/verify-full
* @return this {@link Builder}
*/
public Builder addHost(String host, int port, boolean master, String sslMode) {
this._addresses.add(HostAddress.from(nullOrEmpty(host), port, master, sslMode));
return this;
}

public Builder addPipeHost(String pipe) {
this._addresses.add(HostAddress.pipe(pipe));
return this;
Expand Down
62 changes: 51 additions & 11 deletions src/main/java/org/mariadb/jdbc/HostAddress.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.List;
import java.util.Objects;
import org.mariadb.jdbc.export.HaMode;
import org.mariadb.jdbc.export.SslMode;

/** Host entry */
public class HostAddress {
Expand All @@ -19,6 +20,7 @@ public class HostAddress {

public String pipe;

public SslMode sslMode;
public String localSocket;

/** primary node */
Expand All @@ -34,12 +36,14 @@ public class HostAddress {
* @param port port
* @param primary is primary
*/
private HostAddress(String host, int port, Boolean primary, String pipe, String localSocket) {
private HostAddress(
String host, int port, Boolean primary, String pipe, String localSocket, String sslMode) {
this.host = host;
this.port = port;
this.primary = primary;
this.pipe = pipe;
this.localSocket = localSocket;
this.sslMode = sslMode == null ? null : SslMode.from(sslMode);
}

/**
Expand All @@ -50,15 +54,15 @@ private HostAddress(String host, int port, Boolean primary, String pipe, String
* @return host
*/
public static HostAddress from(String host, int port) {
return new HostAddress(host, port, null, null, null);
return new HostAddress(host, port, null, null, null, null);
}

public static HostAddress pipe(String pipe) {
return new HostAddress(null, 3306, null, pipe, null);
return new HostAddress(null, 3306, null, pipe, null, null);
}

public static HostAddress localSocket(String localSocket) {
return new HostAddress(null, 3306, null, null, localSocket);
return new HostAddress(null, 3306, null, null, localSocket, null);
}

/**
Expand All @@ -70,7 +74,32 @@ public static HostAddress localSocket(String localSocket) {
* @return host
*/
public static HostAddress from(String host, int port, boolean primary) {
return new HostAddress(host, port, primary, null, null);
return new HostAddress(host, port, primary, null, null, null);
}

/**
* Create a Host
*
* @param host host (DNS/IP)
* @param port port
* @param sslMode ssl mode
* @return host
*/
public static HostAddress from(String host, int port, String sslMode) {
return new HostAddress(host, port, null, null, null, sslMode);
}

/**
* Create a Host
*
* @param host host (DNS/IP)
* @param port port
* @param primary is primary
* @param sslMode ssl mode
* @return host
*/
public static HostAddress from(String host, int port, boolean primary, String sslMode) {
return new HostAddress(host, port, primary, null, null, sslMode);
}

/**
Expand Down Expand Up @@ -126,7 +155,7 @@ private static HostAddress parseSimpleHostAddress(String str, HaMode haMode, boo

boolean primary = haMode != HaMode.REPLICATION || first;

return new HostAddress(host, port, primary, null, null);
return new HostAddress(host, port, primary, null, null, null);
}

private static int getPort(String portString) throws SQLException {
Expand All @@ -141,6 +170,9 @@ private static HostAddress parseParameterHostAddress(String str, HaMode haMode,
throws SQLException {
String host = null;
int port = 3306;
String sslMode = null;
String pipe = null;
String localsocket = null;
Boolean primary = null;

String[] array = str.replace(" ", "").split("(?=\\()|(?<=\\))");
Expand All @@ -158,12 +190,17 @@ private static HostAddress parseParameterHostAddress(String str, HaMode haMode,
host = value.replace("[", "").replace("]", "");
break;
case "localsocket":
return new HostAddress(null, 3306, null, null, token[1]);
localsocket = token[1];
break;
case "pipe":
return new HostAddress(null, 3306, null, token[1], null);
pipe = token[1];
break;
case "port":
port = getPort(value);
break;
case "sslmode":
sslMode = token[1];
break;
case "type":
if ("master".equalsIgnoreCase(value) || "primary".equalsIgnoreCase(value)) {
primary = true;
Expand All @@ -185,16 +222,19 @@ private static HostAddress parseParameterHostAddress(String str, HaMode haMode,
}
}

return new HostAddress(host, port, primary, null, null);
return new HostAddress(host, port, primary, pipe, localsocket, sslMode);
}

@Override
public String toString() {
if (pipe != null) return String.format("address=(pipe=%s)", pipe);
if (localSocket != null) return String.format("address=(localSocket=%s)", localSocket);
return String.format(
"address=(host=%s)(port=%s)%s",
host, port, ((primary != null) ? ("(type=" + (primary ? "primary)" : "replica)")) : ""));
"address=(host=%s)(port=%s)%s%s",
host,
port,
(sslMode != null) ? "(sslMode=" + sslMode.getValue() + ")" : "",
((primary != null) ? ("(type=" + (primary ? "primary)" : "replica)")) : ""));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ public static long initializeClientCapabilities(
&& (hostAddress != null && !hostAddress.primary)))) {
capabilities |= Capabilities.CONNECT_WITH_DB;
}

if (configuration.sslMode() != SslMode.DISABLE) {
SslMode sslMode = hostAddress.sslMode == null ? configuration.sslMode() : hostAddress.sslMode;
if (sslMode != SslMode.DISABLE) {
capabilities |= Capabilities.SSL;
}
return capabilities & serverCapabilities;
Expand Down
16 changes: 10 additions & 6 deletions src/main/java/org/mariadb/jdbc/client/impl/StandardClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ public StandardClient(
: new NativePasswordPlugin();
writer.flush();

authenticationHandler(credential);
authenticationHandler(credential, hostAddress);

// **********************************************************************
// activate compression if required
Expand Down Expand Up @@ -259,10 +259,12 @@ public StandardClient(

/**
* @param credential credential
* @param hostAddress host address
* @throws IOException if any socket error occurs
* @throws SQLException if any other kind of issue occurs
*/
public void authenticationHandler(Credential credential) throws IOException, SQLException {
public void authenticationHandler(Credential credential, HostAddress hostAddress)
throws IOException, SQLException {

writer.permitTrace(true);
Configuration conf = context.getConf();
Expand All @@ -288,7 +290,8 @@ public void authenticationHandler(Credential credential) throws IOException, SQL
"08000");
}

authPlugin.initialize(credential.getPassword(), authSwitchPacket.getSeed(), conf);
authPlugin.initialize(
credential.getPassword(), authSwitchPacket.getSeed(), conf, hostAddress);
buf = authPlugin.process(writer, reader, context);
break;

Expand Down Expand Up @@ -473,7 +476,8 @@ public SSLSocket sslWrapper(
throws IOException, SQLException {

Configuration conf = context.getConf();
if (conf.sslMode() != SslMode.DISABLE) {
SslMode sslMode = hostAddress.sslMode == null ? conf.sslMode() : hostAddress.sslMode;
if (sslMode != SslMode.DISABLE) {

if (!context.hasServerCapability(Capabilities.SSL)) {
throw context
Expand All @@ -487,7 +491,7 @@ public SSLSocket sslWrapper(
TlsSocketPlugin socketPlugin = TlsSocketPluginLoader.get(conf.tlsSocketType());
SSLSocketFactory sslSocketFactory;
TrustManager[] trustManagers =
socketPlugin.getTrustManager(conf, context.getExceptionFactory());
socketPlugin.getTrustManager(conf, context.getExceptionFactory(), hostAddress);
try {
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(
Expand Down Expand Up @@ -518,7 +522,7 @@ public SSLSocket sslWrapper(
// (rfc2818 indicate that if "client has external information as to the expected identity of
// the server, the hostname check MAY be omitted")
// validation is only done for not "self-signed" certificates
if (certFingerprint == null && conf.sslMode() == SslMode.VERIFY_FULL && hostAddress != null) {
if (certFingerprint == null && sslMode == SslMode.VERIFY_FULL && hostAddress.host != null) {
SSLSession session = sslSocket.getSession();
try {
socketPlugin.verify(hostAddress.host, session, context.getThreadId());
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/mariadb/jdbc/export/SslMode.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public enum SslMode {
this.aliases = aliases;
}

public String getValue() {
return value;
}

/**
* Create SSLMode from enumeration value, or aliases
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.io.IOException;
import java.sql.SQLException;
import org.mariadb.jdbc.Configuration;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.ReadableByteBuf;
import org.mariadb.jdbc.client.socket.Reader;
Expand All @@ -27,8 +28,10 @@ public interface AuthenticationPlugin {
* @param authenticationData authentication data (password/token)
* @param seed server provided seed
* @param conf Connection options
* @param hostAddress host address
*/
void initialize(String authenticationData, byte[] seed, Configuration conf);
void initialize(
String authenticationData, byte[] seed, Configuration conf, HostAddress hostAddress);

/**
* Process plugin authentication.
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/org/mariadb/jdbc/plugin/TlsSocketPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.sql.SQLException;
import javax.net.ssl.*;
import org.mariadb.jdbc.Configuration;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.export.ExceptionFactory;

/** TLS Socket interface plugin */
Expand All @@ -20,7 +21,8 @@ public interface TlsSocketPlugin {
*/
String type();

TrustManager[] getTrustManager(Configuration conf, ExceptionFactory exceptionFactory)
TrustManager[] getTrustManager(
Configuration conf, ExceptionFactory exceptionFactory, HostAddress hostAddress)
throws SQLException;

KeyManager[] getKeyManager(Configuration conf, ExceptionFactory exceptionFactory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import org.mariadb.jdbc.Configuration;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.ReadableByteBuf;
import org.mariadb.jdbc.client.socket.Reader;
Expand All @@ -29,7 +30,16 @@ public boolean requireSsl() {
return true;
}

public void initialize(String authenticationData, byte[] authData, Configuration conf) {
/**
* Initialization.
*
* @param authenticationData authentication data (password/token)
* @param seed server provided seed
* @param conf Connection string options
* @param hostAddress host information
*/
public void initialize(
String authenticationData, byte[] seed, Configuration conf, HostAddress hostAddress) {
this.authenticationData = authenticationData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.io.IOException;
import java.sql.SQLException;
import org.mariadb.jdbc.Configuration;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.ReadableByteBuf;
import org.mariadb.jdbc.client.impl.StandardReadableByteBuf;
Expand Down Expand Up @@ -45,8 +46,10 @@ public String type() {
* @param authenticationData authentication data (password/token)
* @param seed server provided seed
* @param conf Connection string options
* @param hostAddress host information
*/
public void initialize(String authenticationData, byte[] seed, Configuration conf) {
public void initialize(
String authenticationData, byte[] seed, Configuration conf, HostAddress hostAddress) {
this.seed = seed;
this.optionServicePrincipalName = conf.servicePrincipalName();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.Base64;
import javax.crypto.Cipher;
import org.mariadb.jdbc.Configuration;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.ReadableByteBuf;
import org.mariadb.jdbc.client.socket.Reader;
Expand All @@ -35,6 +36,7 @@ public class CachingSha2PasswordPlugin implements AuthenticationPlugin {
private String authenticationData;
private byte[] seed;
private Configuration conf;
private HostAddress hostAddress;

/**
* Send an SHA-2 encrypted password. encryption XOR(SHA256(password), SHA256(seed,
Expand Down Expand Up @@ -159,11 +161,14 @@ public String type() {
* @param authenticationData authentication data (password/token)
* @param seed server provided seed
* @param conf Connection string options
* @param hostAddress host information
*/
public void initialize(String authenticationData, byte[] seed, Configuration conf) {
public void initialize(
String authenticationData, byte[] seed, Configuration conf, HostAddress hostAddress) {
this.seed = seed;
this.authenticationData = authenticationData;
this.conf = conf;
this.hostAddress = hostAddress;
}

/**
Expand Down Expand Up @@ -197,7 +202,8 @@ public ReadableByteBuf process(Writer out, Reader in, Context context)
case 3:
return in.readReusablePacket();
case 4:
if (conf.sslMode() != SslMode.DISABLE) {
SslMode sslMode = hostAddress.sslMode == null ? conf.sslMode() : hostAddress.sslMode;
if (sslMode != SslMode.DISABLE) {
// send clear password

byte[] bytePwd = authenticationData.getBytes();
Expand Down
Loading

0 comments on commit e8349a8

Please sign in to comment.