Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support createDatabaseIfNotExist #162

Merged
merged 3 commits into from
Dec 15, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add support for create database if not exist
- Support `InitDbMessage`.
- Support `changeDatabase` in `MySqlConnection`.
- Add integration tests for that.
mirromutth committed Dec 15, 2023
commit f22c077fdf5d2ce9e5d217946fe53d3cfddd9ed9
2 changes: 2 additions & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/Capability.java
Original file line number Diff line number Diff line change
@@ -26,6 +26,8 @@ public final class Capability {

/**
* Can use long password.
* <p>
* TODO: Reinterpret it as {@code CLIENT_MYSQL} to support MariaDB 10.2 and above.
*/
private static final int LONG_PASSWORD = 1;

55 changes: 52 additions & 3 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.message.client.InitDbMessage;
import io.asyncer.r2dbc.mysql.message.client.PingMessage;
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
@@ -91,6 +92,31 @@ public final class MySqlConnection implements Connection, ConnectionState {
}
};

private static final BiConsumer<ServerMessage, SynchronousSink<Boolean>> INIT_DB = (message, sink) -> {
if (message instanceof ErrorMessage) {
ErrorMessage msg = (ErrorMessage) message;
logger.debug("Use database failed: [{}] [{}] {}", msg.getCode(), msg.getSqlState(),
msg.getMessage());
sink.next(false);
sink.complete();
} else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
sink.next(true);
sink.complete();
} else {
ReferenceCountUtil.safeRelease(message);
}
};

private static final BiConsumer<ServerMessage, SynchronousSink<Void>> INIT_DB_AFTER = (message, sink) -> {
if (message instanceof ErrorMessage) {
sink.error(((ErrorMessage) message).toException());
} else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
sink.complete();
} else {
ReferenceCountUtil.safeRelease(message);
}
};

private final Client client;

private final Codecs codecs;
@@ -403,13 +429,17 @@ boolean isSessionAutoCommit() {
* @param client must be logged-in.
* @param codecs the {@link Codecs}.
* @param context must be initialized.
* @param database the database that should be lazy init.
* @param queryCache the cache of {@link Query}.
* @param prepareCache the cache of server-preparing result.
* @param prepare judging for prefer use prepare statement to execute simple query.
* @return a {@link Mono} will emit an initialized {@link MySqlConnection}.
*/
static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContext context,
QueryCache queryCache, PrepareCache prepareCache, @Nullable Predicate<String> prepare) {
static Mono<MySqlConnection> init(
Client client, Codecs codecs, ConnectionContext context, String database,
QueryCache queryCache, PrepareCache prepareCache,
@Nullable Predicate<String> prepare
) {
ServerVersion version = context.getServerVersion();
StringBuilder query = new StringBuilder(128);

@@ -431,7 +461,7 @@ static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContex
handler = MySqlConnection::init;
}

return new TextSimpleStatement(client, codecs, context, query.toString())
Mono<MySqlConnection> connection = new TextSimpleStatement(client, codecs, context, query.toString())
.execute()
.flatMap(handler)
.last()
@@ -445,6 +475,25 @@ static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContex
return new MySqlConnection(client, context, codecs, data.level, data.lockWaitTimeout,
queryCache, prepareCache, data.product, prepare);
});

if (database.isEmpty()) {
return connection;
}

requireValidName(database, "database must not be empty and not contain backticks");

return connection.flatMap(conn -> client.exchange(new InitDbMessage(database), INIT_DB)
.last()
.flatMap(success -> {
if (success) {
return Mono.just(conn);
}

String sql = String.format("CREATE DATABASE IF NOT EXISTS `%s`", database);

return QueryFlow.executeVoid(client, sql)
.then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then(Mono.just(conn)));
}));
}

private static Publisher<InitData> init(MySqlResult r) {
41 changes: 25 additions & 16 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java
Original file line number Diff line number Diff line change
@@ -28,7 +28,6 @@
import io.netty.channel.unix.DomainSocketAddress;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
@@ -86,6 +85,7 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
}

String database = configuration.getDatabase();
boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist();
String user = configuration.getUser();
CharSequence password = configuration.getPassword();
SslMode sslMode = ssl.getSslMode();
@@ -95,32 +95,36 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
Predicate<String> prepare = configuration.getPreferPrepareStatement();
int prepareCacheSize = configuration.getPrepareCacheSize();
Publisher<String> passwordPublisher = configuration.getPasswordPublisher();

if (Objects.nonNull(passwordPublisher)) {
return Mono.from(passwordPublisher)
.flatMap(token -> getMySqlConnection(
configuration, queryCache,
ssl, address,
database, user,
sslMode, context,
extensions, prepare,
prepareCacheSize, token));
return Mono.from(passwordPublisher).flatMap(token -> getMySqlConnection(
configuration, queryCache,
ssl, address,
database, createDbIfNotExist,
user, sslMode, context,
extensions, prepare,
prepareCacheSize, token
));
}
return getMySqlConnection(configuration, queryCache,

return getMySqlConnection(
configuration, queryCache,
ssl, address,
database, user,
sslMode, context,
database, createDbIfNotExist,
user, sslMode, context,
extensions, prepare,
prepareCacheSize, password);
prepareCacheSize, password
);
}));
}

@NotNull
private static Mono<MySqlConnection> getMySqlConnection(
final MySqlConnectionConfiguration configuration,
final LazyQueryCache queryCache,
final MySqlSslConfiguration ssl,
final SocketAddress address,
final String database,
final boolean createDbIfNotExist,
final String user,
final SslMode sslMode,
final ConnectionContext context,
@@ -130,16 +134,21 @@ private static Mono<MySqlConnection> getMySqlConnection(
@Nullable final CharSequence password) {
return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(),
context, configuration.getConnectTimeout(), configuration.getSocketTimeout())
.flatMap(client -> QueryFlow.login(client, sslMode, database, user, password, context))
.flatMap(client -> {
// Lazy init database after handshake/login
String db = createDbIfNotExist ? "" : database;
return QueryFlow.login(client, sslMode, db, user, password, context);
})
.flatMap(client -> {
ByteBufAllocator allocator = client.getByteBufAllocator();
CodecsBuilder builder = Codecs.builder(allocator);
PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize);
String db = createDbIfNotExist ? database : "";

extensions.forEach(CodecRegistrar.class, registrar ->
registrar.register(allocator, builder));

return MySqlConnection.init(client, builder.build(), context, queryCache.get(),
return MySqlConnection.init(client, builder.build(), context, db, queryCache.get(),
prepareCache, prepare);
});
}
1 change: 1 addition & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
@@ -888,6 +888,7 @@ private Capability clientCapability(Capability serverCapability) {

builder.disableDatabasePinned();
builder.disableCompression();
// TODO: support LOAD DATA LOCAL INFILE
builder.disableLoadDataInfile();
builder.disableIgnoreAmbiguitySpace();
builder.disableInteractiveTimeout();
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.asyncer.r2dbc.mysql.message.client;

import io.asyncer.r2dbc.mysql.ConnectionContext;
import io.netty.buffer.ByteBuf;

public final class InitDbMessage extends ScalarClientMessage {

private static final byte FLAG = 0x02;

private final String database;

public InitDbMessage(String database) { this.database = database; }

@Override
protected void writeTo(ByteBuf buf, ConnectionContext context) {
// RestOfPacketString, no need terminal or length
buf.writeByte(FLAG).writeCharSequence(database, context.getClientCollation().getCharset());
}
}
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@
class ConnectionIntegrationTest extends IntegrationTestSupport {

ConnectionIntegrationTest() {
super(configuration(false, null, null));
super(configuration("r2dbc", false, false, null, null));
}

@Test
35 changes: 35 additions & 0 deletions src/test/java/io/asyncer/r2dbc/mysql/InitDbIntegrationTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.asyncer.r2dbc.mysql;

import org.junit.jupiter.api.Test;

import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Integration tests for {@code createDatabaseIfNotExist}.
*/
class InitDbIntegrationTest extends IntegrationTestSupport {

private static final String DATABASE = "test-" + ThreadLocalRandom.current().nextInt(10000);

InitDbIntegrationTest() {
super(configuration(
DATABASE, true, false,
null, null
));
}

@Test
void shouldCreateDatabase() {
complete(conn -> conn.createStatement("SHOW DATABASES")
.execute()
.flatMap(it -> it.map((row, rowMetadata) -> row.get(0, String.class)))
.collect(Collectors.toSet())
.doOnNext(it -> assertThat(it).contains(DATABASE))
.thenMany(conn.createStatement("DROP DATABASE `" + DATABASE + "`")
.execute()
.flatMap(MySqlResult::getRowsUpdated)));
}
}
Original file line number Diff line number Diff line change
@@ -71,8 +71,10 @@ static Mono<Long> extractRowsUpdated(Result result) {
return Mono.from(result.getRowsUpdated());
}

static MySqlConnectionConfiguration configuration(boolean autodetectExtensions,
@Nullable ZoneId serverZoneId, @Nullable Predicate<String> preferPrepared) {
static MySqlConnectionConfiguration configuration(
String database, boolean createDatabaseIfNotExist, boolean autodetectExtensions,
@Nullable ZoneId serverZoneId, @Nullable Predicate<String> preferPrepared
) {
String password = System.getProperty("test.mysql.password");

assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty")
@@ -84,7 +86,8 @@ static MySqlConnectionConfiguration configuration(boolean autodetectExtensions,
.connectTimeout(Duration.ofSeconds(3))
.user("root")
.password(password)
.database("r2dbc")
.database(database)
.createDatabaseIfNotExist(createDatabaseIfNotExist)
.autodetectExtensions(autodetectExtensions);

if (serverZoneId != null) {
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
class JacksonPrepareIntegrationTest extends JacksonIntegrationTestSupport {

JacksonPrepareIntegrationTest() {
super(configuration(true, null, sql -> false));
super(configuration("r2dbc", false, true, null, sql -> false));
}
}

Original file line number Diff line number Diff line change
@@ -22,6 +22,6 @@
class JacksonTextIntegrationTest extends JacksonIntegrationTestSupport {

JacksonTextIntegrationTest() {
super(configuration(true, null, null));
super(configuration("r2dbc", false, true, null, null));
}
}
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
class MySqlPrepareTestKit extends MySqlTestKitSupport {

MySqlPrepareTestKit() {
super(IntegrationTestSupport.configuration(false, null, sql -> true));
super(IntegrationTestSupport.configuration("r2dbc", false, false, null, sql -> true));
}

@Override
2 changes: 1 addition & 1 deletion src/test/java/io/asyncer/r2dbc/mysql/MySqlTextTestKit.java
Original file line number Diff line number Diff line change
@@ -22,6 +22,6 @@
class MySqlTextTestKit extends MySqlTestKitSupport {

MySqlTextTestKit() {
super(IntegrationTestSupport.configuration(false, null, null));
super(IntegrationTestSupport.configuration("r2dbc", false, false, null, null));
}
}
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@
class PrepareQueryIntegrationTest extends QueryIntegrationTestSupport {

PrepareQueryIntegrationTest() {
super(configuration(false, null, sql -> true));
super(configuration("r2dbc", false, false, null, sql -> true));
}

@Test
Original file line number Diff line number Diff line change
@@ -22,6 +22,6 @@
class TextQueryIntegrationTest extends QueryIntegrationTestSupport {

TextQueryIntegrationTest() {
super(configuration(false, null, null));
super(configuration("r2dbc", false, false, null, null));
}
}
Original file line number Diff line number Diff line change
@@ -66,7 +66,7 @@ abstract class TimeZoneIntegrationTestSupport extends IntegrationTestSupport {
}

TimeZoneIntegrationTestSupport(@Nullable Predicate<String> preferPrepared) {
super(configuration(false, SERVER_ZONE, preferPrepared));
super(configuration("r2dbc", false, false, SERVER_ZONE, preferPrepared));
}

@Test