From 4834d083c4b58d056b41b547fc214920b5aed971 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 1 Oct 2024 10:28:48 +0200 Subject: [PATCH] Add `StatementFilterFunction` to `R2dbcEntityTemplate`. See #1652 --- .../data/r2dbc/core/R2dbcEntityTemplate.java | 35 +++++++++++++++---- .../core/R2dbcEntityTemplateUnitTests.java | 20 +++++++++-- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java index 2b1252d6d7..bbc55b40a4 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -112,6 +112,8 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw private @Nullable ReactiveEntityCallbacks entityCallbacks; + private Function statementFilterFunction = Function.identity(); + /** * Create a new {@link R2dbcEntityTemplate} given {@link ConnectionFactory}. * @@ -174,6 +176,19 @@ public R2dbcEntityTemplate(DatabaseClient databaseClient, ReactiveDataAccessStra this.projectionFactory = new SpelAwareProxyProjectionFactory(); } + /** + * Set a {@link Function Statement Filter Function} that is applied to every {@link Statement}. + * + * @param statementFilterFunction must not be {@literal null}. + * @since 3.4 + */ + public void setStatementFilterFunction(Function statementFilterFunction) { + + Assert.notNull(statementFilterFunction, "StatementFilterFunction must not be null"); + + this.statementFilterFunction = statementFilterFunction; + } + @Override public DatabaseClient getDatabaseClient() { return this.databaseClient; @@ -274,6 +289,7 @@ Mono doCount(Query query, Class entityClass, SqlIdentifier tableName) { PreparedOperation operation = statementMapper.getMappedObject(selectSpec); return this.databaseClient.sql(operation) // + .filter(statementFilterFunction) // .map((r, md) -> r.get(0, Long.class)) // .first() // .defaultIfEmpty(0L); @@ -302,6 +318,7 @@ Mono doExists(Query query, Class entityClass, SqlIdentifier tableNam PreparedOperation operation = statementMapper.getMappedObject(selectSpec); return this.databaseClient.sql(operation) // + .filter(statementFilterFunction) // .map((r, md) -> r) // .first() // .hasElement(); @@ -362,7 +379,7 @@ private RowsFetchSpec doSelect(Query query, Class entityType, SqlIdent PreparedOperation operation = statementMapper.getMappedObject(selectSpec); return getRowsFetchSpec( - databaseClient.sql(operation).filter(filterFunction), + databaseClient.sql(operation).filter(statementFilterFunction.andThen(filterFunction)), entityType, returnType ); @@ -397,7 +414,7 @@ Mono doUpdate(Query query, Update update, Class entityClass, SqlIdentif } PreparedOperation operation = statementMapper.getMappedObject(selectSpec); - return this.databaseClient.sql(operation).fetch().rowsUpdated(); + return this.databaseClient.sql(operation).filter(statementFilterFunction).fetch().rowsUpdated(); } @Override @@ -422,7 +439,7 @@ Mono doDelete(Query query, Class entityClass, SqlIdentifier tableName) } PreparedOperation operation = statementMapper.getMappedObject(deleteSpec); - return this.databaseClient.sql(operation).fetch().rowsUpdated().defaultIfEmpty(0L); + return this.databaseClient.sql(operation).filter(statementFilterFunction).fetch().rowsUpdated().defaultIfEmpty(0L); } // ------------------------------------------------------------------------- @@ -441,7 +458,8 @@ public RowsFetchSpec query(PreparedOperation operation, Class entit Assert.notNull(operation, "PreparedOperation must not be null"); Assert.notNull(entityClass, "Entity class must not be null"); - return new EntityCallbackAdapter<>(getRowsFetchSpec(databaseClient.sql(operation), entityClass, resultType), + return new EntityCallbackAdapter<>( + getRowsFetchSpec(databaseClient.sql(operation).filter(statementFilterFunction), entityClass, resultType), getTableNameOrEmpty(entityClass)); } @@ -451,7 +469,8 @@ public RowsFetchSpec query(PreparedOperation operation, BiFunction(databaseClient.sql(operation).map(rowMapper), SqlIdentifier.EMPTY); + return new EntityCallbackAdapter<>(databaseClient.sql(operation).filter(statementFilterFunction).map(rowMapper), + SqlIdentifier.EMPTY); } @Override @@ -462,7 +481,8 @@ public RowsFetchSpec query(PreparedOperation operation, Class entit Assert.notNull(entityClass, "Entity class must not be null"); Assert.notNull(rowMapper, "Row mapper must not be null"); - return new EntityCallbackAdapter<>(databaseClient.sql(operation).map(rowMapper), getTableNameOrEmpty(entityClass)); + return new EntityCallbackAdapter<>(databaseClient.sql(operation).filter(statementFilterFunction).map(rowMapper), + getTableNameOrEmpty(entityClass)); } // ------------------------------------------------------------------------- @@ -541,6 +561,8 @@ private Mono doInsert(T entity, SqlIdentifier tableName, OutboundRow outb return this.databaseClient.sql(operation) // .filter(statement -> { + statement = statementFilterFunction.apply(statement); + if (identifierColumns.isEmpty()) { return statement.returnGeneratedValues(); } @@ -632,6 +654,7 @@ private Mono doUpdate(T entity, SqlIdentifier tableName, RelationalPersis PreparedOperation operation = mapper.getMappedObject(updateSpec); return this.databaseClient.sql(operation) // + .filter(statementFilterFunction) // .fetch() // .rowsUpdated() // .handle((rowsUpdated, sink) -> { diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java index e654859168..f8aed4ff79 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java @@ -206,8 +206,6 @@ void shouldProjectCountResultWithoutId() { @Test // GH-469 void shouldExistsByCriteria() { - MockRowMetadata metadata = MockRowMetadata.builder() - .columnMetadata(MockColumnMetadata.builder().name("name").type(R2dbcType.VARCHAR).build()).build(); MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); recorder.addStubbing(s -> s.startsWith("SELECT"), result); @@ -654,6 +652,24 @@ void projectDtoShouldReadPropertiesOnce() { }).verifyComplete(); } + @Test // GH-1652 + void shouldConsiderFilterFunction() { + + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.setStatementFilterFunction(statement -> statement.fetchSize(10)); + entityTemplate.count(Query.empty(), Person.class) // + .as(StepVerifier::create) // + .expectNext(1L) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getFetchSize()).isEqualTo(10); + } + @ReadingConverter static class PkConverter implements Converter {