From 581329d02398c14daa54e8761abf441a747cfa96 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 18 Feb 2020 14:18:16 +0100 Subject: [PATCH 1/5] #189 - Prepare issue branch. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 6e43082e..e1b445b2 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-r2dbc - 1.1.0.BUILD-SNAPSHOT + 1.1.0.gh-189-SNAPSHOT Spring Data R2DBC Spring Data module for R2DBC From 1675af74379bcf3e5164429bd3b442042f6c9cec Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 18 Feb 2020 14:45:28 +0100 Subject: [PATCH 2/5] #189 - Accept StatementFilterFunction in DatabaseClient. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We now accept StatementFilterFunction and ExecuteFunction via DatabaseClient to filter Statement execution. StatementFilterFunctions can be used to pre-process the statement or post-process Result objects. databaseClient.execute(…) .filter((s, next) -> next.execute(s.returnGeneratedValues("my_id"))) .filter((s, next) -> next.execute(s.fetchSize(25))) databaseClient.execute(…) .filter(s -> s.returnGeneratedValues("my_id")) .filter(s -> s.fetchSize(25)) --- src/main/asciidoc/new-features.adoc | 3 +- src/main/asciidoc/reference/r2dbc-sql.adoc | 38 +++++- .../data/r2dbc/core/DatabaseClient.java | 44 ++++++- .../r2dbc/core/DefaultDatabaseClient.java | 99 +++++++++------ .../core/DefaultDatabaseClientBuilder.java | 21 +++- .../data/r2dbc/core/ExecuteFunction.java | 46 +++++++ .../r2dbc/core/StatementFilterFunction.java | 65 ++++++++++ .../r2dbc/core/StatementFilterFunctions.java | 46 +++++++ .../core/DefaultDatabaseClientUnitTests.java | 114 ++++++++++++++++++ 9 files changed, 434 insertions(+), 42 deletions(-) create mode 100644 src/main/java/org/springframework/data/r2dbc/core/ExecuteFunction.java create mode 100644 src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java create mode 100644 src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunctions.java diff --git a/src/main/asciidoc/new-features.adoc b/src/main/asciidoc/new-features.adoc index 6f468840..07d4504e 100644 --- a/src/main/asciidoc/new-features.adoc +++ b/src/main/asciidoc/new-features.adoc @@ -5,7 +5,8 @@ == What's New in Spring Data R2DBC 1.1.0 RELEASE * Introduction of `R2dbcEntityTemplate` for entity-oriented operations. -* Support interface projections with `DatabaseClient.as(…)` +* Support interface projections with `DatabaseClient.as(…)`. +* <>. [[new-features.1-0-0-RELEASE]] == What's New in Spring Data R2DBC 1.0.0 RELEASE diff --git a/src/main/asciidoc/reference/r2dbc-sql.adoc b/src/main/asciidoc/reference/r2dbc-sql.adoc index 58c0ce07..7a750855 100644 --- a/src/main/asciidoc/reference/r2dbc-sql.adoc +++ b/src/main/asciidoc/reference/r2dbc-sql.adoc @@ -134,7 +134,7 @@ In JDBC, the actual drivers translate `?` bind markers to database-native marker Spring Data R2DBC lets you use native bind markers or named bind markers with the `:name` syntax. -Named parameter support leverages a `R2dbcDialect` instance to expand named parameters to native bind markers at the time of query execution, which gives you a certain degree of query portability across various database vendors. +Named parameter support leverages a `R2dbcDialect` instance to expand named parameters to native bind markers at the time of query execution, which gives you a certain degree of query portability across various database vendors. **** The query-preprocessor unrolls named `Collection` parameters into a series of bind markers to remove the need of dynamic query creation based on the number of arguments. @@ -159,7 +159,7 @@ tuples.add(new Object[] {"John", 35}); tuples.add(new Object[] {"Ann", 50}); db.execute("SELECT id, name, state FROM table WHERE (name, age) IN (:tuples)") - .bind("tuples", tuples); + .bind("tuples", tuples) ---- ==== @@ -171,6 +171,38 @@ The following example shows a simpler variant using `IN` predicates: [source,java] ---- db.execute("SELECT id, name, state FROM table WHERE age IN (:ages)") - .bind("ages", Arrays.asList(35, 50)); + .bind("ages", Arrays.asList(35, 50)) ---- ==== + +[[r2dbc.datbaseclient.filter]] +== Statement Filters + +You can register a `Statement` filter (`StatementFilterFunction`) through `DatabaseClient` to intercept and modify statements in their execution, as the following example shows: + +==== +[source,java] +---- +db.execute("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter((s, next) -> next.execute(s.returnGeneratedValues("id"))) + .bind("name", …) + .bind("state", …) +---- +==== + +`DatabaseClient` exposes also simplified `filter(…)` overload accepting `UnaryOperator`: + +==== +[source,java] +---- +db.execute("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter(s -> s.returnGeneratedValues("id")) + .bind("name", …) + .bind("state", …) + +db.execute("SELECT id, name, state FROM table") + .filter(s -> s.fetchSize(25)) +---- +==== + +`StatementFilterFunction` allow filtering of the executed `Statement` and filtering of `Result` objects. diff --git a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java index 7fde95f6..613917ef 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java @@ -18,6 +18,7 @@ import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; +import io.r2dbc.spi.Statement; import reactor.core.publisher.Mono; import java.util.Arrays; @@ -26,6 +27,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import java.util.function.UnaryOperator; import org.reactivestreams.Publisher; @@ -37,6 +39,7 @@ import org.springframework.data.r2dbc.query.Update; import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.util.Assert; /** * A non-blocking, reactive client for performing database calls requests with Reactive Streams back pressure. Provides @@ -142,6 +145,16 @@ interface Builder { */ Builder exceptionTranslator(R2dbcExceptionTranslator exceptionTranslator); + /** + * Configures a {@link ExecuteFunction} to execute {@link Statement} objects. + * + * @param executeFunction must not be {@literal null}. + * @return {@code this} {@link Builder}. + * @since 1.1 + * @see Statement#execute() + */ + Builder executeFunction(ExecuteFunction executeFunction); + /** * Configures a {@link ReactiveDataAccessStrategy}. * @@ -186,7 +199,7 @@ interface Builder { /** * Contract for specifying a SQL call along with options leading to the exchange. */ - interface GenericExecuteSpec extends BindSpec { + interface GenericExecuteSpec extends BindSpec, StatementFilterSpec { /** * Define the target type the result should be mapped to.
@@ -231,7 +244,7 @@ interface GenericExecuteSpec extends BindSpec { /** * Contract for specifying a SQL call along with options leading to the exchange. */ - interface TypedExecuteSpec extends BindSpec> { + interface TypedExecuteSpec extends BindSpec>, StatementFilterSpec> { /** * Define the target type the result should be mapped to.
@@ -866,4 +879,31 @@ interface BindSpec> { */ S bindNull(String name, Class type); } + + /** + * Contract for applying a {@link StatementFilterFunction}. + * + * @since 1.1 + */ + interface StatementFilterSpec> { + + /** + * Add the given filter to the end of the filter chain. + * + * @param filter the filter to be added to the chain. + */ + default S filter(UnaryOperator filter) { + + Assert.notNull(filter, "Statement FilterFunction must not be null!"); + + return filter((statement, next) -> next.execute(filter.apply(statement))); + } + + /** + * Add the given filter to the end of the filter chain. + * + * @param filter the filter to be added to the chain. + */ + S filter(StatementFilterFunction filter); + } } diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java index 7dc37bfe..f95b356d 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java @@ -78,6 +78,8 @@ class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor { private final R2dbcExceptionTranslator exceptionTranslator; + private final ExecuteFunction executeFunction; + private final ReactiveDataAccessStrategy dataAccessStrategy; private final boolean namedParameters; @@ -87,11 +89,12 @@ class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor { private final ProjectionFactory projectionFactory; DefaultDatabaseClient(ConnectionFactory connector, R2dbcExceptionTranslator exceptionTranslator, - ReactiveDataAccessStrategy dataAccessStrategy, boolean namedParameters, ProjectionFactory projectionFactory, - DefaultDatabaseClientBuilder builder) { + ExecuteFunction executeFunction, ReactiveDataAccessStrategy dataAccessStrategy, boolean namedParameters, + ProjectionFactory projectionFactory, DefaultDatabaseClientBuilder builder) { this.connector = connector; this.exceptionTranslator = exceptionTranslator; + this.executeFunction = executeFunction; this.dataAccessStrategy = dataAccessStrategy; this.namedParameters = namedParameters; this.projectionFactory = projectionFactory; @@ -264,25 +267,26 @@ protected DataAccessException translateException(String task, @Nullable String s * Customization hook. */ protected DefaultTypedExecuteSpec createTypedExecuteSpec(Map byIndex, - Map byName, Supplier sqlSupplier, Class typeToRead) { - return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, typeToRead); + Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction, + Class typeToRead) { + return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, filterFunction, typeToRead); } /** * Customization hook. */ protected DefaultTypedExecuteSpec createTypedExecuteSpec(Map byIndex, - Map byName, Supplier sqlSupplier, + Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction, BiFunction mappingFunction) { - return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, mappingFunction); + return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, filterFunction, mappingFunction); } /** * Customization hook. */ protected ExecuteSpecSupport createGenericExecuteSpec(Map byIndex, - Map byName, Supplier sqlSupplier) { - return new DefaultGenericExecuteSpec(byIndex, byName, sqlSupplier); + Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction) { + return new DefaultGenericExecuteSpec(byIndex, byName, sqlSupplier, filterFunction); } /** @@ -327,19 +331,22 @@ class ExecuteSpecSupport { final Map byIndex; final Map byName; final Supplier sqlSupplier; + final StatementFilterFunction filterFunction; ExecuteSpecSupport(Supplier sqlSupplier) { this.byIndex = Collections.emptyMap(); this.byName = Collections.emptyMap(); this.sqlSupplier = sqlSupplier; + this.filterFunction = StatementFilterFunctions.empty(); } ExecuteSpecSupport(Map byIndex, Map byName, - Supplier sqlSupplier) { + Supplier sqlSupplier, StatementFilterFunction filterFunction) { this.byIndex = byIndex; this.byName = byName; this.sqlSupplier = sqlSupplier; + this.filterFunction = filterFunction; } FetchSpec exchange(Supplier sqlSupplier, BiFunction mappingFunction) { @@ -404,7 +411,7 @@ FetchSpec exchange(Supplier sqlSupplier, BiFunction> resultFunction = toExecuteFunction(sql, executeFunction); + Function> resultFunction = toFunction(sql, filterFunction, executeFunction); return new DefaultSqlResult<>(DefaultDatabaseClient.this, // sql, // @@ -426,7 +433,7 @@ public ExecuteSpecSupport bind(int index, Object value) { byIndex.put(index, SettableValue.fromOrEmpty(value, value.getClass())); } - return createInstance(byIndex, this.byName, this.sqlSupplier); + return createInstance(byIndex, this.byName, this.sqlSupplier, this.filterFunction); } public ExecuteSpecSupport bindNull(int index, Class type) { @@ -436,7 +443,7 @@ public ExecuteSpecSupport bindNull(int index, Class type) { Map byIndex = new LinkedHashMap<>(this.byIndex); byIndex.put(index, SettableValue.empty(type)); - return createInstance(byIndex, this.byName, this.sqlSupplier); + return createInstance(byIndex, this.byName, this.sqlSupplier, this.filterFunction); } public ExecuteSpecSupport bind(String name, Object value) { @@ -455,7 +462,7 @@ public ExecuteSpecSupport bind(String name, Object value) { byName.put(name, SettableValue.fromOrEmpty(value, value.getClass())); } - return createInstance(this.byIndex, byName, this.sqlSupplier); + return createInstance(this.byIndex, byName, this.sqlSupplier, this.filterFunction); } public ExecuteSpecSupport bindNull(String name, Class type) { @@ -466,7 +473,14 @@ public ExecuteSpecSupport bindNull(String name, Class type) { Map byName = new LinkedHashMap<>(this.byName); byName.put(name, SettableValue.empty(type)); - return createInstance(this.byIndex, byName, this.sqlSupplier); + return createInstance(this.byIndex, byName, this.sqlSupplier, this.filterFunction); + } + + public ExecuteSpecSupport filter(StatementFilterFunction filter) { + + Assert.notNull(filter, "Statement FilterFunction must not be null!"); + + return createInstance(this.byIndex, byName, this.sqlSupplier, this.filterFunction.andThen(filter)); } private void assertNotPreparedOperation() { @@ -476,8 +490,8 @@ private void assertNotPreparedOperation() { } protected ExecuteSpecSupport createInstance(Map byIndex, Map byName, - Supplier sqlSupplier) { - return new ExecuteSpecSupport(byIndex, byName, sqlSupplier); + Supplier sqlSupplier, StatementFilterFunction filterFunction) { + return new ExecuteSpecSupport(byIndex, byName, sqlSupplier, filterFunction); } } @@ -487,8 +501,8 @@ protected ExecuteSpecSupport createInstance(Map byIndex, protected class DefaultGenericExecuteSpec extends ExecuteSpecSupport implements GenericExecuteSpec { DefaultGenericExecuteSpec(Map byIndex, Map byName, - Supplier sqlSupplier) { - super(byIndex, byName, sqlSupplier); + Supplier sqlSupplier, StatementFilterFunction filterFunction) { + super(byIndex, byName, sqlSupplier, filterFunction); } DefaultGenericExecuteSpec(Supplier sqlSupplier) { @@ -500,7 +514,7 @@ public TypedExecuteSpec as(Class resultType) { Assert.notNull(resultType, "Result type must not be null!"); - return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, resultType); + return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, this.filterFunction, resultType); } @Override @@ -549,10 +563,15 @@ public DefaultGenericExecuteSpec bindNull(String name, Class type) { return (DefaultGenericExecuteSpec) super.bindNull(name, type); } + @Override + public DefaultGenericExecuteSpec filter(StatementFilterFunction filter) { + return (DefaultGenericExecuteSpec) super.filter(filter); + } + @Override protected ExecuteSpecSupport createInstance(Map byIndex, Map byName, - Supplier sqlSupplier) { - return createGenericExecuteSpec(byIndex, byName, sqlSupplier); + Supplier sqlSupplier, StatementFilterFunction filterFunction) { + return createGenericExecuteSpec(byIndex, byName, sqlSupplier, filterFunction); } } @@ -566,9 +585,9 @@ protected class DefaultTypedExecuteSpec extends ExecuteSpecSupport implements private final BiFunction mappingFunction; DefaultTypedExecuteSpec(Map byIndex, Map byName, - Supplier sqlSupplier, Class typeToRead) { + Supplier sqlSupplier, StatementFilterFunction filterFunction, Class typeToRead) { - super(byIndex, byName, sqlSupplier); + super(byIndex, byName, sqlSupplier, filterFunction); this.typeToRead = typeToRead; @@ -581,9 +600,10 @@ protected class DefaultTypedExecuteSpec extends ExecuteSpecSupport implements } DefaultTypedExecuteSpec(Map byIndex, Map byName, - Supplier sqlSupplier, BiFunction mappingFunction) { + Supplier sqlSupplier, StatementFilterFunction filterFunction, + BiFunction mappingFunction) { - super(byIndex, byName, sqlSupplier); + super(byIndex, byName, sqlSupplier, filterFunction); this.typeToRead = null; this.mappingFunction = mappingFunction; @@ -594,7 +614,7 @@ public TypedExecuteSpec as(Class resultType) { Assert.notNull(resultType, "Result type must not be null!"); - return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, resultType); + return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, this.filterFunction, resultType); } @Override @@ -643,10 +663,15 @@ public DefaultTypedExecuteSpec bindNull(String name, Class type) { return (DefaultTypedExecuteSpec) super.bindNull(name, type); } + @Override + public DefaultTypedExecuteSpec filter(StatementFilterFunction filter) { + return (DefaultTypedExecuteSpec) super.filter(filter); + } + @Override protected DefaultTypedExecuteSpec createInstance(Map byIndex, - Map byName, Supplier sqlSupplier) { - return createTypedExecuteSpec(byIndex, byName, sqlSupplier, this.typeToRead); + Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction) { + return createTypedExecuteSpec(byIndex, byName, sqlSupplier, filterFunction, this.typeToRead); } } @@ -735,7 +760,8 @@ FetchSpec execute(PreparedOperation preparedOperation, BiFunction selectFunction = wrapPreparedOperation(sql, preparedOperation); - Function> resultFunction = DefaultDatabaseClient.toExecuteFunction(sql, selectFunction); + Function> resultFunction = toFunction(sql, StatementFilterFunctions.empty(), + selectFunction); return new DefaultSqlResult<>(DefaultDatabaseClient.this, // sql, // @@ -1432,7 +1458,8 @@ private FetchSpec exchangeInsert(BiFunction mappingF String sql = getRequiredSql(operation); Function insertFunction = wrapPreparedOperation(sql, operation) .andThen(statement -> statement.returnGeneratedValues()); - Function> resultFunction = toExecuteFunction(sql, insertFunction); + Function> resultFunction = toFunction(sql, StatementFilterFunctions.empty(), + insertFunction); return new DefaultSqlResult<>(this, // sql, // @@ -1445,7 +1472,8 @@ private UpdatedRowsFetchSpec exchangeUpdate(PreparedOperation operation) { String sql = getRequiredSql(operation); Function executeFunction = wrapPreparedOperation(sql, operation); - Function> resultFunction = toExecuteFunction(sql, executeFunction); + Function> resultFunction = toFunction(sql, StatementFilterFunctions.empty(), + executeFunction); return new DefaultSqlResult<>(this, // sql, // @@ -1476,12 +1504,15 @@ private Function wrapPreparedOperation(String sql, Prepar }; } - private static Function> toExecuteFunction(String sql, - Function executeFunction) { + private Function> toFunction(String sql, StatementFilterFunction filterFunction, + Function statementFactory) { return it -> { - Flux from = Flux.defer(() -> executeFunction.apply(it).execute()).cast(Result.class); + Flux from = Flux.defer(() -> { + Statement statement = statementFactory.apply(it); + return filterFunction.filter(statement, executeFunction); + }).cast(Result.class); return from.checkpoint("SQL \"" + sql + "\" [DatabaseClient]"); }; } diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java index c3a186da..5f08ab95 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java @@ -17,6 +17,7 @@ package org.springframework.data.r2dbc.core; import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Statement; import java.util.function.Consumer; @@ -40,6 +41,8 @@ class DefaultDatabaseClientBuilder implements DatabaseClient.Builder { private @Nullable R2dbcExceptionTranslator exceptionTranslator; + private ExecuteFunction executeFunction = Statement::execute; + private ReactiveDataAccessStrategy accessStrategy; private boolean namedParameters = true; @@ -54,6 +57,7 @@ class DefaultDatabaseClientBuilder implements DatabaseClient.Builder { this.connectionFactory = other.connectionFactory; this.exceptionTranslator = other.exceptionTranslator; + this.executeFunction = other.executeFunction; this.accessStrategy = other.accessStrategy; this.namedParameters = other.namedParameters; this.projectionFactory = other.projectionFactory; @@ -85,6 +89,19 @@ public Builder exceptionTranslator(R2dbcExceptionTranslator exceptionTranslator) return this; } + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.function.DatabaseClient.Builder#executeFunction(org.springframework.data.r2dbc.core.ExecuteFunction) + */ + @Override + public Builder executeFunction(ExecuteFunction executeFunction) { + + Assert.notNull(executeFunction, "ExecuteFunction must not be null!"); + + this.executeFunction = executeFunction; + return this; + } + /* * (non-Javadoc) * @see org.springframework.data.r2dbc.function.DatabaseClient.Builder#dataAccessStrategy(org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy) @@ -143,8 +160,8 @@ public DatabaseClient build() { accessStrategy = new DefaultReactiveDataAccessStrategy(dialect); } - return new DefaultDatabaseClient(this.connectionFactory, exceptionTranslator, accessStrategy, namedParameters, - projectionFactory, new DefaultDatabaseClientBuilder(this)); + return new DefaultDatabaseClient(this.connectionFactory, exceptionTranslator, executeFunction, accessStrategy, + namedParameters, projectionFactory, new DefaultDatabaseClientBuilder(this)); } /* diff --git a/src/main/java/org/springframework/data/r2dbc/core/ExecuteFunction.java b/src/main/java/org/springframework/data/r2dbc/core/ExecuteFunction.java new file mode 100644 index 00000000..773916da --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ExecuteFunction.java @@ -0,0 +1,46 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; + +import java.util.function.BiFunction; + +import org.reactivestreams.Publisher; + +/** + * Represents a function that executes a {@link io.r2dbc.spi.Statement} for a (delayed) {@link io.r2dbc.spi.Result} + * stream. + *

+ * Note that discarded {@link Result} objects must be consumed according to the R2DBC spec via either + * {@link Result#getRowsUpdated()} or {@link Result#map(BiFunction)}. + * + * @author Mark Paluch + * @since 1.1 + * @see Statement#execute() + */ +@FunctionalInterface +public interface ExecuteFunction { + + /** + * Execute the given {@link Statement} for a stream of {@link Result}s. + * + * @param statement the request to execute. + * @return the delayed result stream. + */ + Publisher execute(Statement statement); +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java new file mode 100644 index 00000000..c5a271f7 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java @@ -0,0 +1,65 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; + +import org.reactivestreams.Publisher; + +import org.springframework.util.Assert; + +/** + * Represents a function that filters an {@link ExecuteFunction execute function}. + *

+ * The filter is executed when a {@link org.reactivestreams.Subscriber} subscribes to the {@link Publisher} returned by + * the {@link DatabaseClient}. + * + * @author Mark Paluch + * @since 1.1 + * @see ExecuteFunction + */ +@FunctionalInterface +public interface StatementFilterFunction { + + /** + * Apply this filter to the given {@link Statement} and {@link ExecuteFunction}. + *

+ * The given {@link ExecuteFunction} represents the next entity in the chain, to be invoked via + * {@link ExecuteFunction#execute(Statement)} invoked} in order to proceed with the exchange, or not invoked to + * shortcut the chain. + * + * @param statement the current {@link Statement}. + * @param next the next exchange function in the chain. + * @return the filtered {@link Result}s. + */ + Publisher filter(Statement statement, ExecuteFunction next); + + /** + * Return a composed filter function that first applies this filter, and then applies the given {@code "after"} + * filter. + * + * @param afterFilter the filter to apply after this filter. + * @return the composed filter. + */ + default StatementFilterFunction andThen(StatementFilterFunction afterFilter) { + + Assert.notNull(afterFilter, "StatementFilterFunction must not be null"); + + return (request, next) -> filter(request, afterRequest -> afterFilter.filter(afterRequest, next)); + } + +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunctions.java b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunctions.java new file mode 100644 index 00000000..e9788992 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunctions.java @@ -0,0 +1,46 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; + +import org.reactivestreams.Publisher; + +/** + * Collection of default {@link StatementFilterFunction}s. + * + * @author Mark Paluch + * @since 1.1 + */ +enum StatementFilterFunctions implements StatementFilterFunction { + + EMPTY_FILTER; + + @Override + public Publisher filter(Statement statement, ExecuteFunction next) { + return next.execute(statement); + } + + /** + * Return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}. + * + * @return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}. + */ + public static StatementFilterFunction empty() { + return EMPTY_FILTER; + } +} diff --git a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java index 5c88fec9..de571c97 100644 --- a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -37,6 +37,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.reactivestreams.Publisher; @@ -468,7 +469,120 @@ public void shouldProjectTypedSelectAs() { }) // .verifyComplete(); + } + + @Test // gh-189 + public void shouldApplyExecuteFunction() { + + Statement statement = mock(Statement.class); + when(connection.createStatement(anyString())).thenReturn(statement); + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified(0, Object.class, "Walter").build()).build(); + + DatabaseClient databaseClient = DatabaseClient.builder() // + .connectionFactory(connectionFactory) // + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // + .executeFunction(it -> Mono.just(result)).build(); + + databaseClient.execute("SELECT") // + .fetch().all() // + .as(StepVerifier::create) // + .expectNextCount(1).verifyComplete(); + + verify(statement, never()).execute(); + } + + @Test // gh-189 + public void shouldApplyStatementFilterFunctions() { + + Statement statement = mock(Statement.class); + when(connection.createStatement(anyString())).thenReturn(statement); + when(statement.returnGeneratedValues(anyString())).thenReturn(statement); + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata).build(); + + doReturn(Flux.just(result)).when(statement).execute(); + + DatabaseClient databaseClient = DatabaseClient.builder() // + .connectionFactory(connectionFactory) // + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // + .build(); + + databaseClient.execute("SELECT") // + .filter((s, next) -> next.execute(s.returnGeneratedValues("foo"))) // + .filter((s, next) -> next.execute(s.returnGeneratedValues("bar"))) // + .fetch().all() // + .as(StepVerifier::create) // + .verifyComplete(); + + InOrder inOrder = inOrder(statement); + inOrder.verify(statement).returnGeneratedValues("foo"); + inOrder.verify(statement).returnGeneratedValues("bar"); + inOrder.verify(statement).execute(); + } + + @Test // gh-189 + public void shouldApplyStatementFilterFunctionsToTypedExecute() { + + Statement statement = mock(Statement.class); + when(connection.createStatement(anyString())).thenReturn(statement); + when(statement.returnGeneratedValues(anyString())).thenReturn(statement); + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata).build(); + + doReturn(Flux.just(result)).when(statement).execute(); + + DatabaseClient databaseClient = DatabaseClient.builder() // + .connectionFactory(connectionFactory) // + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // + .build(); + + databaseClient.execute("SELECT") // + .filter((s, next) -> next.execute(s.returnGeneratedValues("foo"))) // + .as(Person.class) // + .fetch().all() // + .as(StepVerifier::create) // + .verifyComplete(); + + verify(statement).returnGeneratedValues("foo"); + } + + @Test // gh-189 + public void shouldApplySimpleStatementFilterFunctions() { + + Statement statement = mock(Statement.class); + when(connection.createStatement(anyString())).thenReturn(statement); + when(statement.returnGeneratedValues(anyString())).thenReturn(statement); + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata).build(); + + doReturn(Flux.just(result)).when(statement).execute(); + + DatabaseClient databaseClient = DatabaseClient.builder() // + .connectionFactory(connectionFactory) // + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // + .build(); + + databaseClient.execute("SELECT") // + .filter(s -> s.returnGeneratedValues("foo")) // + .filter(s -> s.returnGeneratedValues("bar")) // + .fetch().all() // + .as(StepVerifier::create) // + .verifyComplete(); + InOrder inOrder = inOrder(statement); + inOrder.verify(statement).returnGeneratedValues("foo"); + inOrder.verify(statement).returnGeneratedValues("bar"); + inOrder.verify(statement).execute(); } static class Person { From c3fb8b31d5c64da40405f6faa25438ac7c3654e8 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Mon, 24 Feb 2020 15:41:51 +0100 Subject: [PATCH 3/5] #189 - Polishing. Made assertions in tests more strict. --- .../data/r2dbc/core/DefaultDatabaseClient.java | 1 + .../data/r2dbc/core/DefaultDatabaseClientUnitTests.java | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java index f95b356d..37777d83 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java @@ -343,6 +343,7 @@ class ExecuteSpecSupport { ExecuteSpecSupport(Map byIndex, Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction) { + this.byIndex = byIndex; this.byName = byName; this.sqlSupplier = sqlSupplier; diff --git a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java index de571c97..bdfef8ea 100644 --- a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -492,7 +492,7 @@ public void shouldApplyExecuteFunction() { .as(StepVerifier::create) // .expectNextCount(1).verifyComplete(); - verify(statement, never()).execute(); + verifyNoInteractions(statement); } @Test // gh-189 @@ -524,6 +524,7 @@ public void shouldApplyStatementFilterFunctions() { inOrder.verify(statement).returnGeneratedValues("foo"); inOrder.verify(statement).returnGeneratedValues("bar"); inOrder.verify(statement).execute(); + inOrder.verifyNoMoreInteractions(); } @Test // gh-189 @@ -551,7 +552,10 @@ public void shouldApplyStatementFilterFunctionsToTypedExecute() { .as(StepVerifier::create) // .verifyComplete(); - verify(statement).returnGeneratedValues("foo"); + InOrder inOrder = inOrder(statement); + inOrder.verify(statement).returnGeneratedValues("foo"); + inOrder.verify(statement).execute(); + inOrder.verifyNoMoreInteractions(); } @Test // gh-189 @@ -583,6 +587,7 @@ public void shouldApplySimpleStatementFilterFunctions() { inOrder.verify(statement).returnGeneratedValues("foo"); inOrder.verify(statement).returnGeneratedValues("bar"); inOrder.verify(statement).execute(); + inOrder.verifyNoMoreInteractions(); } static class Person { From 46e3fb6c3baec8339341d3bda7c89a81e593acf1 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Tue, 25 Feb 2020 11:17:40 +0100 Subject: [PATCH 4/5] #189 - Polishing. Refactored DefaultDatabaseClientUnitTests in order to make the relevant differences in setup easier to spot. --- .../core/DefaultDatabaseClientUnitTests.java | 266 +++++++----------- 1 file changed, 101 insertions(+), 165 deletions(-) diff --git a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java index bdfef8ea..ba8b4372 100644 --- a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -39,47 +39,49 @@ import org.junit.runner.RunWith; import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; - import org.springframework.beans.factory.annotation.Value; import org.springframework.data.annotation.Id; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.mapping.SettableValue; -import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; +import org.springframework.lang.Nullable; /** * Unit tests for {@link DefaultDatabaseClient}. * * @author Mark Paluch * @author Ferdinand Jacobs + * @author Jens Schauder */ @RunWith(MockitoJUnitRunner.class) public class DefaultDatabaseClientUnitTests { - @Mock ConnectionFactory connectionFactory; @Mock Connection connection; - @Mock R2dbcExceptionTranslator translator; + private DatabaseClient.Builder databaseClientBuilder; @Before public void before() { + + ConnectionFactory connectionFactory = Mockito.mock(ConnectionFactory.class); + when(connectionFactory.create()).thenReturn((Publisher) Mono.just(connection)); when(connection.close()).thenReturn(Mono.empty()); + + databaseClientBuilder = DatabaseClient.builder() // + .connectionFactory(connectionFactory) // + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)); } @Test // gh-48 public void shouldCloseConnectionOnlyOnce() { - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) - .exceptionTranslator(translator).build(); + DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) databaseClientBuilder.build(); - Flux flux = databaseClient.inConnectionMany(it -> { - return Flux.empty(); - }); + Flux flux = databaseClient.inConnectionMany(it -> Flux.empty()); flux.subscribe(new CoreSubscriber() { Subscription subscription; @@ -108,13 +110,9 @@ public void onComplete() { @Test // gh-128 public void executeShouldBindNullValues() { - Statement statement = mock(Statement.class); - when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT * FROM table WHERE key = $1") // .bindNull(0, String.class) // @@ -136,13 +134,9 @@ public void executeShouldBindNullValues() { @Test // gh-162 public void executeShouldBindSettableValues() { - Statement statement = mock(Statement.class); - when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT * FROM table WHERE key = $1") // .bind(0, SettableValue.empty(String.class)) // @@ -164,13 +158,8 @@ public void executeShouldBindSettableValues() { @Test // gh-128 public void executeShouldBindNamedNullValues() { - Statement statement = mock(Statement.class); - when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); - - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT * FROM table WHERE key = :key") // .bindNull("key", String.class) // @@ -184,14 +173,9 @@ public void executeShouldBindNamedNullValues() { @Test // gh-178 public void executeShouldBindNamedValuesFromIndexes() { - Statement statement = mock(Statement.class); - when(connection.createStatement("SELECT id, name, manual FROM legoset WHERE name IN ($1, $2, $3)")) - .thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); + Statement statement = mockStatementFor("SELECT id, name, manual FROM legoset WHERE name IN ($1, $2, $3)"); - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT id, name, manual FROM legoset WHERE name IN (:name)") // .bind(0, Arrays.asList("unknown", "dunno", "other")) // @@ -209,13 +193,9 @@ public void executeShouldBindNamedValuesFromIndexes() { @Test // gh-128, gh-162 public void executeShouldBindValues() { - Statement statement = mock(Statement.class); - when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT * FROM table WHERE key = $1") // .bind(0, SettableValue.from("foo")) // @@ -237,14 +217,8 @@ public void executeShouldBindValues() { @Test // gh-162 public void insertShouldAcceptNullValues() { - Statement statement = mock(Statement.class); - when(connection.createStatement("INSERT INTO foo (first, second) VALUES ($1, $2)")).thenReturn(statement); - when(statement.returnGeneratedValues()).thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); - - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + Statement statement = mockStatementFor("INSERT INTO foo (first, second) VALUES ($1, $2)"); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.insert().into("foo") // .value("first", "foo") // @@ -260,14 +234,8 @@ public void insertShouldAcceptNullValues() { @Test // gh-162 public void insertShouldAcceptSettableValue() { - Statement statement = mock(Statement.class); - when(connection.createStatement("INSERT INTO foo (first, second) VALUES ($1, $2)")).thenReturn(statement); - when(statement.returnGeneratedValues()).thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); - - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + Statement statement = mockStatementFor("INSERT INTO foo (first, second) VALUES ($1, $2)"); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.insert().into("foo") // .value("first", SettableValue.from("foo")) // @@ -283,13 +251,8 @@ public void insertShouldAcceptSettableValue() { @Test // gh-128 public void executeShouldBindNamedValuesByIndex() { - Statement statement = mock(Statement.class); - when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); - - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT * FROM table WHERE key = :key") // .bind("key", "foo") // @@ -303,14 +266,8 @@ public void executeShouldBindNamedValuesByIndex() { @Test // gh-177 public void deleteNotInShouldRenderCorrectQuery() { - Statement statement = mock(Statement.class); - when(connection.createStatement("DELETE FROM tab WHERE tab.pole = $1 AND tab.id NOT IN ($2, $3)")) - .thenReturn(statement); - when(statement.execute()).thenReturn(Mono.empty()); - - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + Statement statement = mockStatementFor("DELETE FROM tab WHERE tab.pole = $1 AND tab.id NOT IN ($2, $3)"); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.delete().from("tab").matching(where("pole").is("foo").and("id").notIn(1, 2)) // .then() // @@ -318,23 +275,18 @@ public void deleteNotInShouldRenderCorrectQuery() { .verifyComplete(); verify(statement).bind(0, "foo"); - verify(statement).bind(1, (Object) 1); - verify(statement).bind(2, (Object) 2); + verify(statement).bind(1, 1); + verify(statement).bind(2, 2); } @Test // gh-243 public void rowsUpdatedShouldEmitSingleValue() { - Statement statement = mock(Statement.class); - when(connection.createStatement("DROP TABLE tab;")).thenReturn(statement); Result result = mock(Result.class); - doReturn(Flux.just(result)).when(statement).execute(); - - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); - when(result.getRowsUpdated()).thenReturn(Mono.empty(), Mono.just(2), Flux.just(1, 2, 3)); + mockStatementFor("DROP TABLE tab;", result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("DROP TABLE tab;") // .fetch() // @@ -361,10 +313,7 @@ public void rowsUpdatedShouldEmitSingleValue() { @Test // gh-250 public void shouldThrowExceptionForSingleColumnObjectUpdate() { - DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // - .build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); assertThatIllegalArgumentException().isThrownBy(() -> databaseClient.update() // .table(IdOnly.class) // @@ -375,20 +324,11 @@ public void shouldThrowExceptionForSingleColumnObjectUpdate() { @Test // gh-260 public void shouldProjectGenericExecuteAs() { - Statement statement = mock(Statement.class); - when(connection.createStatement(anyString())).thenReturn(statement); - - MockRowMetadata metadata = MockRowMetadata.builder() - .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); - MockResult result = MockResult.builder().rowMetadata(metadata) - .row(MockRow.builder().identified(0, Object.class, "Walter").build()).build(); - - doReturn(Flux.just(result)).when(statement).execute(); + MockResult result = mockSingleColumnResult(MockRow.builder().identified(0, Object.class, "Walter")); + mockStatement(result); - DatabaseClient databaseClient = DatabaseClient.builder() // - .connectionFactory(connectionFactory) // + DatabaseClient databaseClient = databaseClientBuilder // .projectionFactory(new SpelAwareProxyProjectionFactory()) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // .build(); databaseClient.execute("SELECT * FROM person") // @@ -408,20 +348,11 @@ public void shouldProjectGenericExecuteAs() { @Test // gh-260 public void shouldProjectGenericSelectAs() { - Statement statement = mock(Statement.class); - when(connection.createStatement(anyString())).thenReturn(statement); + MockResult result = mockSingleColumnResult(MockRow.builder().identified(0, Object.class, "Walter")); + mockStatement(result); - MockRowMetadata metadata = MockRowMetadata.builder() - .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); - MockResult result = MockResult.builder().rowMetadata(metadata) - .row(MockRow.builder().identified(0, Object.class, "Walter").build()).build(); - - doReturn(Flux.just(result)).when(statement).execute(); - - DatabaseClient databaseClient = DatabaseClient.builder() // - .connectionFactory(connectionFactory) // + DatabaseClient databaseClient = databaseClientBuilder // .projectionFactory(new SpelAwareProxyProjectionFactory()) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // .build(); databaseClient.select().from("person") // @@ -442,20 +373,11 @@ public void shouldProjectGenericSelectAs() { @Test // gh-260 public void shouldProjectTypedSelectAs() { - Statement statement = mock(Statement.class); - when(connection.createStatement(anyString())).thenReturn(statement); + MockResult result = mockSingleColumnResult(MockRow.builder().identified("name", Object.class, "Walter")); + mockStatement(result); - MockRowMetadata metadata = MockRowMetadata.builder() - .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); - MockResult result = MockResult.builder().rowMetadata(metadata) - .row(MockRow.builder().identified("name", Object.class, "Walter").build()).build(); - - doReturn(Flux.just(result)).when(statement).execute(); - - DatabaseClient databaseClient = DatabaseClient.builder() // - .connectionFactory(connectionFactory) // + DatabaseClient databaseClient = databaseClientBuilder // .projectionFactory(new SpelAwareProxyProjectionFactory()) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // .build(); databaseClient.select().from(Person.class) // @@ -474,18 +396,12 @@ public void shouldProjectTypedSelectAs() { @Test // gh-189 public void shouldApplyExecuteFunction() { - Statement statement = mock(Statement.class); - when(connection.createStatement(anyString())).thenReturn(statement); + Statement statement = mockStatement(); + MockResult result = mockSingleColumnResult(MockRow.builder().identified(0, Object.class, "Walter")); - MockRowMetadata metadata = MockRowMetadata.builder() - .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); - MockResult result = MockResult.builder().rowMetadata(metadata) - .row(MockRow.builder().identified(0, Object.class, "Walter").build()).build(); - - DatabaseClient databaseClient = DatabaseClient.builder() // - .connectionFactory(connectionFactory) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // - .executeFunction(it -> Mono.just(result)).build(); + DatabaseClient databaseClient = databaseClientBuilder // + .executeFunction(it -> Mono.just(result)) // + .build(); databaseClient.execute("SELECT") // .fetch().all() // @@ -498,20 +414,13 @@ public void shouldApplyExecuteFunction() { @Test // gh-189 public void shouldApplyStatementFilterFunctions() { - Statement statement = mock(Statement.class); - when(connection.createStatement(anyString())).thenReturn(statement); - when(statement.returnGeneratedValues(anyString())).thenReturn(statement); - MockRowMetadata metadata = MockRowMetadata.builder() .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); MockResult result = MockResult.builder().rowMetadata(metadata).build(); - doReturn(Flux.just(result)).when(statement).execute(); + Statement statement = mockStatement(result); - DatabaseClient databaseClient = DatabaseClient.builder() // - .connectionFactory(connectionFactory) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // - .build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT") // .filter((s, next) -> next.execute(s.returnGeneratedValues("foo"))) // @@ -530,20 +439,13 @@ public void shouldApplyStatementFilterFunctions() { @Test // gh-189 public void shouldApplyStatementFilterFunctionsToTypedExecute() { - Statement statement = mock(Statement.class); - when(connection.createStatement(anyString())).thenReturn(statement); - when(statement.returnGeneratedValues(anyString())).thenReturn(statement); - MockRowMetadata metadata = MockRowMetadata.builder() .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); MockResult result = MockResult.builder().rowMetadata(metadata).build(); - doReturn(Flux.just(result)).when(statement).execute(); + Statement statement = mockStatement(result); - DatabaseClient databaseClient = DatabaseClient.builder() // - .connectionFactory(connectionFactory) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // - .build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT") // .filter((s, next) -> next.execute(s.returnGeneratedValues("foo"))) // @@ -561,20 +463,11 @@ public void shouldApplyStatementFilterFunctionsToTypedExecute() { @Test // gh-189 public void shouldApplySimpleStatementFilterFunctions() { - Statement statement = mock(Statement.class); - when(connection.createStatement(anyString())).thenReturn(statement); - when(statement.returnGeneratedValues(anyString())).thenReturn(statement); + MockResult result = mockSingleColumnEmptyResult(); - MockRowMetadata metadata = MockRowMetadata.builder() - .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); - MockResult result = MockResult.builder().rowMetadata(metadata).build(); - - doReturn(Flux.just(result)).when(statement).execute(); + Statement statement = mockStatement(result); - DatabaseClient databaseClient = DatabaseClient.builder() // - .connectionFactory(connectionFactory) // - .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) // - .build(); + DatabaseClient databaseClient = databaseClientBuilder.build(); databaseClient.execute("SELECT") // .filter(s -> s.returnGeneratedValues("foo")) // @@ -590,6 +483,49 @@ public void shouldApplySimpleStatementFilterFunctions() { inOrder.verifyNoMoreInteractions(); } + private Statement mockStatement() { + return mockStatementFor(null, null); + } + + private Statement mockStatement(Result result) { + return mockStatementFor(null, result); + } + + private Statement mockStatementFor(String sql) { + return mockStatementFor(sql, null); + } + + private Statement mockStatementFor(@Nullable String sql, @Nullable Result result) { + + Statement statement = mock(Statement.class); + when(connection.createStatement(sql == null ? anyString() : eq(sql))).thenReturn(statement); + when(statement.returnGeneratedValues(anyString())).thenReturn(statement); + when(statement.returnGeneratedValues()).thenReturn(statement); + + doReturn(result == null ? Mono.empty() : Flux.just(result)).when(statement).execute(); + + return statement; + } + + private MockResult mockSingleColumnEmptyResult() { + return mockSingleColumnResult(null); + } + + /** + * Mocks a {@link Result} with a single column "name" and a single row if a non null row is provided. + */ + private MockResult mockSingleColumnResult(@Nullable MockRow.Builder row) { + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + + MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata); + if (row != null) { + resultBuilder = resultBuilder.row(row.build()); + } + return resultBuilder.build(); + } + static class Person { String name; From 5c078dbd4144e28b5cb931cee64a80ce5cf3e745 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 25 Feb 2020 14:30:16 +0100 Subject: [PATCH 5/5] #189 - Incorporate review feedback. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix nullability annotations. Relax generics at DatabaseClient.StatementFilterSpec.filter(…). --- .../data/r2dbc/core/DatabaseClient.java | 3 +- .../r2dbc/core/DefaultDatabaseClient.java | 68 +++++++------------ .../r2dbc/core/StatementFilterFunction.java | 1 - 3 files changed, 26 insertions(+), 46 deletions(-) diff --git a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java index 613917ef..6e090499 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java @@ -27,7 +27,6 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; -import java.util.function.UnaryOperator; import org.reactivestreams.Publisher; @@ -892,7 +891,7 @@ interface StatementFilterSpec> { * * @param filter the filter to be added to the chain. */ - default S filter(UnaryOperator filter) { + default S filter(Function filter) { Assert.notNull(filter, "Statement FilterFunction must not be null!"); diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java index 37777d83..f4204c9c 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java @@ -272,15 +272,6 @@ protected DefaultTypedExecuteSpec createTypedExecuteSpec(Map(byIndex, byName, sqlSupplier, filterFunction, typeToRead); } - /** - * Customization hook. - */ - protected DefaultTypedExecuteSpec createTypedExecuteSpec(Map byIndex, - Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction, - BiFunction mappingFunction) { - return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, filterFunction, mappingFunction); - } - /** * Customization hook. */ @@ -354,7 +345,7 @@ FetchSpec exchange(Supplier sqlSupplier, BiFunction executeFunction = it -> { + Function statementFactory = it -> { if (logger.isDebugEnabled()) { logger.debug("Executing SQL statement [" + sql + "]"); @@ -412,7 +403,7 @@ FetchSpec exchange(Supplier sqlSupplier, BiFunction> resultFunction = toFunction(sql, filterFunction, executeFunction); + Function> resultFunction = toFunction(sql, filterFunction, statementFactory); return new DefaultSqlResult<>(DefaultDatabaseClient.this, // sql, // @@ -582,7 +573,7 @@ protected ExecuteSpecSupport createInstance(Map byIndex, @SuppressWarnings("unchecked") protected class DefaultTypedExecuteSpec extends ExecuteSpecSupport implements TypedExecuteSpec { - private final @Nullable Class typeToRead; + private final Class typeToRead; private final BiFunction mappingFunction; DefaultTypedExecuteSpec(Map byIndex, Map byName, @@ -600,16 +591,6 @@ protected class DefaultTypedExecuteSpec extends ExecuteSpecSupport implements } } - DefaultTypedExecuteSpec(Map byIndex, Map byName, - Supplier sqlSupplier, StatementFilterFunction filterFunction, - BiFunction mappingFunction) { - - super(byIndex, byName, sqlSupplier, filterFunction); - - this.typeToRead = null; - this.mappingFunction = mappingFunction; - } - @Override public TypedExecuteSpec as(Class resultType) { @@ -717,8 +698,8 @@ private abstract class DefaultSelectSpecSupport { this.page = Pageable.unpaged(); } - DefaultSelectSpecSupport(SqlIdentifier table, List projectedFields, Criteria criteria, Sort sort, - Pageable page) { + DefaultSelectSpecSupport(SqlIdentifier table, List projectedFields, @Nullable Criteria criteria, + Sort sort, Pageable page) { this.table = table; this.projectedFields = projectedFields; this.criteria = criteria; @@ -772,13 +753,13 @@ FetchSpec execute(PreparedOperation preparedOperation, BiFunction projectedFields, - Criteria criteria, Sort sort, Pageable page); + @Nullable Criteria criteria, Sort sort, Pageable page); } private class DefaultGenericSelectSpec extends DefaultSelectSpecSupport implements GenericSelectSpec { - DefaultGenericSelectSpec(SqlIdentifier table, List projectedFields, Criteria criteria, Sort sort, - Pageable page) { + DefaultGenericSelectSpec(SqlIdentifier table, List projectedFields, @Nullable Criteria criteria, + Sort sort, Pageable page) { super(table, projectedFields, criteria, sort, page); } @@ -861,7 +842,7 @@ private FetchSpec exchange(BiFunction mappingFunctio @Override protected DefaultGenericSelectSpec createInstance(SqlIdentifier table, List projectedFields, - Criteria criteria, Sort sort, Pageable page) { + @Nullable Criteria criteria, Sort sort, Pageable page) { return new DefaultGenericSelectSpec(table, projectedFields, criteria, sort, page); } } @@ -883,8 +864,8 @@ private class DefaultTypedSelectSpec extends DefaultSelectSpecSupport impleme this.mappingFunction = dataAccessStrategy.getRowMapper(typeToRead); } - DefaultTypedSelectSpec(SqlIdentifier table, List projectedFields, Criteria criteria, Sort sort, - Pageable page, @Nullable Class typeToRead, BiFunction mappingFunction) { + DefaultTypedSelectSpec(SqlIdentifier table, List projectedFields, @Nullable Criteria criteria, + Sort sort, Pageable page, Class typeToRead, BiFunction mappingFunction) { super(table, projectedFields, criteria, sort, page); @@ -975,7 +956,7 @@ private FetchSpec exchange(BiFunction mappingFunctio @Override protected DefaultTypedSelectSpec createInstance(SqlIdentifier table, List projectedFields, - Criteria criteria, Sort sort, Pageable page) { + @Nullable Criteria criteria, Sort sort, Pageable page) { return new DefaultTypedSelectSpec<>(table, projectedFields, criteria, sort, page, this.typeToRead, this.mappingFunction); } @@ -1223,11 +1204,11 @@ class DefaultGenericUpdateSpec implements GenericUpdateSpec, UpdateMatchingSpec private final @Nullable Class typeToUpdate; private final @Nullable SqlIdentifier table; - private final Update assignments; - private final Criteria where; + private final @Nullable Update assignments; + private final @Nullable Criteria where; - DefaultGenericUpdateSpec(@Nullable Class typeToUpdate, @Nullable SqlIdentifier table, Update assignments, - Criteria where) { + DefaultGenericUpdateSpec(@Nullable Class typeToUpdate, @Nullable SqlIdentifier table, + @Nullable Update assignments, @Nullable Criteria where) { this.typeToUpdate = typeToUpdate; this.table = table; this.assignments = assignments; @@ -1256,6 +1237,7 @@ public UpdatedRowsFetchSpec fetch() { SqlIdentifier table; if (StringUtils.isEmpty(this.table)) { + Assert.state(this.typeToUpdate != null, "Type to update must not be null!"); table = dataAccessStrategy.getTableName(this.typeToUpdate); } else { table = this.table; @@ -1277,6 +1259,7 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) { mapper = mapper.forType(this.typeToUpdate); } + Assert.state(this.assignments != null, "Update assignments must not be null!"); StatementMapper.UpdateSpec update = mapper.createUpdate(table, this.assignments); if (this.where != null) { @@ -1291,11 +1274,11 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) { class DefaultTypedUpdateSpec implements TypedUpdateSpec, UpdateSpec { - private final @Nullable Class typeToUpdate; + private final Class typeToUpdate; private final @Nullable SqlIdentifier table; - private final T objectToUpdate; + private final @Nullable T objectToUpdate; - DefaultTypedUpdateSpec(@Nullable Class typeToUpdate, @Nullable SqlIdentifier table, T objectToUpdate) { + DefaultTypedUpdateSpec(Class typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate) { this.typeToUpdate = typeToUpdate; this.table = table; this.objectToUpdate = objectToUpdate; @@ -1390,9 +1373,9 @@ class DefaultDeleteSpec implements DeleteMatchingSpec, TypedDeleteSpec { private final @Nullable Class typeToDelete; private final @Nullable SqlIdentifier table; - private final Criteria where; + private final @Nullable Criteria where; - DefaultDeleteSpec(@Nullable Class typeToDelete, @Nullable SqlIdentifier table, Criteria where) { + DefaultDeleteSpec(@Nullable Class typeToDelete, @Nullable SqlIdentifier table, @Nullable Criteria where) { this.typeToDelete = typeToDelete; this.table = table; this.where = where; @@ -1420,6 +1403,7 @@ public UpdatedRowsFetchSpec fetch() { SqlIdentifier table; if (StringUtils.isEmpty(this.table)) { + Assert.state(this.typeToDelete != null, "Type to delete must not be null!"); table = dataAccessStrategy.getTableName(this.typeToDelete); } else { table = this.table; @@ -1608,9 +1592,7 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl // Invoke method on target Connection. try { - Object retVal = method.invoke(this.target, args); - - return retVal; + return method.invoke(this.target, args); } catch (InvocationTargetException ex) { throw ex.getTargetException(); } diff --git a/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java index c5a271f7..520b7ab6 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java +++ b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java @@ -61,5 +61,4 @@ default StatementFilterFunction andThen(StatementFilterFunction afterFilter) { return (request, next) -> filter(request, afterRequest -> afterFilter.filter(afterRequest, next)); } - }