diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java index 49e221f6685a..cbc97ce0150f 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java @@ -167,7 +167,7 @@ interface GenericExecuteSpec { * Bind a non-{@code null} value to a parameter identified by its * {@code index}. {@code value} can be either a scalar value or {@link io.r2dbc.spi.Parameter}. * @param index zero based index to bind the parameter to - * @param value either a scalar value or {@link io.r2dbc.spi.Parameter} + * @param value either a scalar value or a {@link io.r2dbc.spi.Parameter} */ GenericExecuteSpec bind(int index, Object value); @@ -181,7 +181,7 @@ interface GenericExecuteSpec { /** * Bind a non-{@code null} value to a parameter identified by its {@code name}. * @param name the name of the parameter - * @param value the value to bind + * @param value either a scalar value or a {@link io.r2dbc.spi.Parameter} */ GenericExecuteSpec bind(String name, Object value); @@ -192,11 +192,22 @@ interface GenericExecuteSpec { */ GenericExecuteSpec bindNull(String name, Class type); + /** + * Bind the parameter values from the given source map, + * registering each as a parameter with the map key as name. + * @param source the source map of parameters, with keys as names and + * each value either a scalar value or a {@link io.r2dbc.spi.Parameter} + * @since 6.1 + * @see #bindProperties + */ + GenericExecuteSpec bindValues(Map source); + /** * Bind the bean properties or record components from the given * source object, registering each as a named parameter. * @param source the source object (a JavaBean or record) * @since 6.1 + * @see #mapProperties */ GenericExecuteSpec bindProperties(Object source); diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java index 56b28c5379c9..4921ee48eecb 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java @@ -249,6 +249,19 @@ class DefaultGenericExecuteSpec implements GenericExecuteSpec { } @SuppressWarnings("deprecation") + private Parameter resolveParameter(Object value) { + if (value instanceof Parameter param) { + return param; + } + else if (value instanceof org.springframework.r2dbc.core.Parameter param) { + Object paramValue = param.getValue(); + return (paramValue != null ? Parameters.in(paramValue) : Parameters.in(param.getType())); + } + else { + return Parameters.in(value); + } + } + @Override public DefaultGenericExecuteSpec bind(int index, Object value) { assertNotPreparedOperation(); @@ -256,16 +269,7 @@ public DefaultGenericExecuteSpec bind(int index, Object value) { "Value at index %d must not be null. Use bindNull(…) instead.", index)); Map byIndex = new LinkedHashMap<>(this.byIndex); - if (value instanceof Parameter param) { - byIndex.put(index, param); - } - else if (value instanceof org.springframework.r2dbc.core.Parameter param) { - Object pv = param.getValue(); - byIndex.put(index, (pv != null ? Parameters.in(pv) : Parameters.in(param.getType()))); - } - else { - byIndex.put(index, Parameters.in(value)); - } + byIndex.put(index, resolveParameter(value)); return new DefaultGenericExecuteSpec(byIndex, this.byName, this.sqlSupplier, this.filterFunction); } @@ -290,15 +294,7 @@ public DefaultGenericExecuteSpec bind(String name, Object value) { "Value for parameter %s must not be null. Use bindNull(…) instead.", name)); Map byName = new LinkedHashMap<>(this.byName); - if (value instanceof Parameter p) { - byName.put(name, p); - } - else if (value instanceof org.springframework.r2dbc.core.Parameter p) { - byName.put(name, p.hasValue() ? Parameters.in(p.getValue()) : Parameters.in(p.getType())); - } - else { - byName.put(name, Parameters.in(value)); - } + byName.put(name, resolveParameter(value)); return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction); } @@ -314,6 +310,17 @@ public DefaultGenericExecuteSpec bindNull(String name, Class type) { return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction); } + @Override + public GenericExecuteSpec bindValues(Map source) { + assertNotPreparedOperation(); + Assert.notNull(source, "Parameter source must not be null"); + + Map target = new LinkedHashMap<>(this.byName); + source.forEach((name, value) -> target.put(name, resolveParameter(value))); + + return new DefaultGenericExecuteSpec(this.byIndex, target, this.sqlSupplier, this.filterFunction); + } + @Override public DefaultGenericExecuteSpec bindProperties(Object source) { assertNotPreparedOperation(); diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java index 734fef3a506c..de0f3cf19adf 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java @@ -16,7 +16,10 @@ package org.springframework.r2dbc.core; +import java.util.Map; + import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Parameters; import io.r2dbc.spi.Result; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -93,6 +96,27 @@ public void executeInsert() { .verifyComplete(); } + @Test + public void executeInsertWithMap() { + DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); + + databaseClient.sql("INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)") + .bindValues(Map.of("id", 42055, + "name", Parameter.from("SCHAUFELRADBAGGER"), + "manual", Parameters.in(Integer.class))) + .fetch().rowsUpdated() + .as(StepVerifier::create) + .expectNext(1L) + .verifyComplete(); + + databaseClient.sql("SELECT id FROM legoset") + .mapValue(Integer.class) + .first() + .as(StepVerifier::create) + .assertNext(actual -> assertThat(actual).isEqualTo(42055)) + .verifyComplete(); + } + @Test public void executeInsertWithRecords() { DatabaseClient databaseClient = DatabaseClient.create(connectionFactory);