diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java index a17517cc1..92ea0aa56 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java @@ -31,6 +31,7 @@ import io.netty.util.internal.logging.InternalLoggerFactory; import io.r2dbc.spi.Connection; import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.Lifecycle; import io.r2dbc.spi.TransactionDefinition; import io.r2dbc.spi.ValidationDepth; import org.jetbrains.annotations.Nullable; @@ -52,7 +53,7 @@ /** * An implementation of {@link Connection} for connecting to the MySQL database. */ -public final class MySqlConnection implements Connection, ConnectionState { +public final class MySqlConnection implements Connection, Lifecycle, ConnectionState { private static final InternalLogger logger = InternalLoggerFactory.getInstance(MySqlConnection.class); @@ -278,6 +279,17 @@ public MySqlStatement createStatement(String sql) { return new PrepareParametrizedStatement(client, codecs, query, context, prepareCache); } + @Override + public Mono postAllocate() { + return Mono.empty(); + } + + @Override + public Mono preRelease() { + // Rollback if the connection is in transaction. + return rollbackTransaction(); + } + @Override public Mono releaseSavepoint(String name) { requireValidName(name, "Savepoint name must not be empty and not contain backticks"); diff --git a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java index 18eb12ff8..cb51f38ee 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java @@ -22,6 +22,7 @@ import java.time.Duration; import java.util.Arrays; +import java.util.Collections; import static io.r2dbc.spi.IsolationLevel.READ_COMMITTED; import static io.r2dbc.spi.IsolationLevel.READ_UNCOMMITTED; @@ -52,6 +53,54 @@ void isInTransaction() { .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())); } + @Test + void autoRollbackPreRelease() { + // Mock pool allocate/release. + complete(conn -> conn.postAllocate() + .thenMany(conn.createStatement("CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY)") + .execute()) + .flatMap(MySqlResult::getRowsUpdated) + .then(conn.beginTransaction()) + .thenMany(conn.createStatement("INSERT INTO test VALUES (1)") + .execute()) + .flatMap(MySqlResult::getRowsUpdated) + .single() + .doOnNext(it -> assertThat(it).isEqualTo(1)) + .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isTrue()) + .then(conn.preRelease()) + .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isFalse()) + .then(conn.postAllocate()) + .thenMany(conn.createStatement("SELECT * FROM test") + .execute()) + .flatMap(it -> it.map((row, metadata) -> row.get(0, Integer.class))) + .count() + .doOnNext(it -> assertThat(it).isZero())); + } + + @Test + void shouldNotRollbackCommittedPreRelease() { + // Mock pool allocate/release. + complete(conn -> conn.postAllocate() + .thenMany(conn.createStatement("CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY)") + .execute()) + .flatMap(MySqlResult::getRowsUpdated) + .then(conn.beginTransaction()) + .thenMany(conn.createStatement("INSERT INTO test VALUES (1)") + .execute()) + .flatMap(MySqlResult::getRowsUpdated) + .single() + .doOnNext(it -> assertThat(it).isEqualTo(1)) + .then(conn.commitTransaction()) + .then(conn.preRelease()) + .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isFalse()) + .then(conn.postAllocate()) + .thenMany(conn.createStatement("SELECT * FROM test") + .execute()) + .flatMap(it -> it.map((row, metadata) -> row.get(0, Integer.class))) + .collectList() + .doOnNext(it -> assertThat(it).isEqualTo(Collections.singletonList(1)))); + } + @Test void transactionDefinitionLockWaitTimeout() { complete(connection -> connection.beginTransaction(MySqlTransactionDefinition.builder()