Skip to content

Commit

Permalink
Polishing.
Browse files Browse the repository at this point in the history
Simplify type and interface arrangement.

See #1601
Original pull request: #1617
  • Loading branch information
mp911de committed Sep 26, 2023
1 parent 7b27d0e commit 8fa9e3e
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
import org.springframework.data.relational.core.sqlgeneration.SqlGenerator;
import org.springframework.data.relational.domain.RowDocument;
import org.springframework.data.util.Streamable;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Reads complete Aggregates from the database, by generating appropriate SQL using a {@link SingleQuerySqlGenerator}
Expand All @@ -53,13 +53,14 @@
* @author Mark Paluch
* @since 3.2
*/
class AggregateReader<T> {
class AggregateReader<T> implements PathToColumnMapping {

private final RelationalPersistentEntity<T> aggregate;
private final Table table;
private final SqlGenerator sqlGenerator;
private final JdbcConverter converter;
private final NamedParameterJdbcOperations jdbcTemplate;
private final AliasFactory aliasFactory;
private final RowDocumentResultSetExtractor extractor;

AggregateReader(Dialect dialect, JdbcConverter converter, AliasFactory aliasFactory,
Expand All @@ -70,8 +71,25 @@ class AggregateReader<T> {
this.jdbcTemplate = jdbcTemplate;
this.table = Table.create(aggregate.getQualifiedTableName());
this.sqlGenerator = new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate);
this.extractor = new RowDocumentResultSetExtractor(converter.getMappingContext(),
createPathToColumnMapping(aliasFactory));
this.aliasFactory = aliasFactory;
this.extractor = new RowDocumentResultSetExtractor(converter.getMappingContext(), this);
}

@Override
public String column(AggregatePath path) {

String alias = aliasFactory.getColumnAlias(path);

if (alias == null) {
throw new IllegalStateException(String.format("Alias for '%s' must not be null", path));
}

return alias;
}

@Override
public String keyColumn(AggregatePath path) {
return aliasFactory.getKeyAlias(path);
}

@Nullable
Expand All @@ -84,30 +102,34 @@ public T findById(Object id) {

@Nullable
public T findOne(Query query) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
Condition condition = createCondition(query, parameterSource);

return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractZeroOrOne);
}

public List<T> findAll() {
return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll);
return doFind(query, this::extractZeroOrOne);
}

public List<T> findAllById(Iterable<?> ids) {

Collection<?> identifiers = ids instanceof Collection<?> idl ? idl : Streamable.of(ids).toList();
Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers)).limit(1);
Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers));

return findAll(query);
}

@SuppressWarnings("ConstantConditions")
public List<T> findAll() {
return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll);
}

public List<T> findAll(Query query) {
return doFind(query, this::extractAll);
}

@SuppressWarnings("ConstantConditions")
private <R> R doFind(Query query, ResultSetExtractor<R> extractor) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
Condition condition = createCondition(query, parameterSource);
return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractAll);
String sql = sqlGenerator.findAll(condition);

return jdbcTemplate.query(sql, parameterSource, extractor);
}

@Nullable
Expand All @@ -128,7 +150,7 @@ private Condition createCondition(Query query, MapSqlParameterSource parameterSo
*
* @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}.
* @return a {@code List} of aggregates, fully converted.
* @throws SQLException
* @throws SQLException on underlying JDBC errors.
*/
private List<T> extractAll(ResultSet rs) throws SQLException {

Expand All @@ -146,10 +168,10 @@ private List<T> extractAll(ResultSet rs) throws SQLException {
* {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms
* to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract.
*
* @param @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}.
* @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}.
* @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is
* empty, null is returned.
* @throws SQLException
* @throws SQLException on underlying JDBC errors.
* @throws IncorrectResultSizeDataAccessException when the conversion yields more than one instance.
*/
@Nullable
Expand All @@ -167,21 +189,4 @@ private T extractZeroOrOne(ResultSet rs) throws SQLException {
return null;
}

private PathToColumnMapping createPathToColumnMapping(AliasFactory aliasFactory) {
return new PathToColumnMapping() {
@Override
public String column(AggregatePath path) {

String alias = aliasFactory.getColumnAlias(path);
Assert.notNull(alias, () -> "alias for >" + path + "< must not be null");
return alias;
}

@Override
public String keyColumn(AggregatePath path) {
return aliasFactory.getKeyAlias(path);
}
};
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private class RowDocumentIterator implements Iterator<RowDocument> {
*/
private boolean hasNext;

RowDocumentIterator(RelationalPersistentEntity<?> entity, ResultSet resultSet) throws SQLException {
RowDocumentIterator(RelationalPersistentEntity<?> entity, ResultSet resultSet) {

ResultSetAdapter adapter = ResultSetAdapter.INSTANCE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.data.jdbc.core.convert;

import java.util.List;
import java.util.Optional;

import org.springframework.data.domain.Pageable;
Expand Down Expand Up @@ -56,22 +57,22 @@ public <T> T findById(Object id, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType) {
public <T> List<T> findAll(Class<T> domainType) {
return getReader(domainType).findAll();
}

@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return getReader(domainType).findAllById(ids);
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
public <T> List<T> findAll(Class<T> domainType, Sort sort) {
throw new UnsupportedOperationException();
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
throw new UnsupportedOperationException();
}

Expand All @@ -81,12 +82,12 @@ public <T> Optional<T> findOne(Query query, Class<T> domainType) {
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
public <T> List<T> findAll(Query query, Class<T> domainType) {
return getReader(domainType).findAll(query);
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@
import static org.springframework.data.jdbc.testing.TestDatabaseFeatures.Feature.*;

import java.time.LocalDateTime;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;

Expand All @@ -49,6 +57,7 @@
import org.springframework.data.jdbc.core.convert.JdbcConverter;
import org.springframework.data.jdbc.testing.EnabledOnFeature;
import org.springframework.data.jdbc.testing.IntegrationTest;
import org.springframework.data.jdbc.testing.TestClass;
import org.springframework.data.jdbc.testing.TestConfiguration;
import org.springframework.data.jdbc.testing.TestDatabaseFeatures;
import org.springframework.data.mapping.context.InvalidPersistentPropertyPath;
Expand All @@ -63,6 +72,7 @@
import org.springframework.data.relational.core.query.Query;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.ContextConfiguration;

/**
* Integration tests for {@link JdbcAggregateTemplate}.
Expand Down Expand Up @@ -1927,8 +1937,8 @@ static class WithInsertOnly {
static class Config {

@Bean
Class<?> testClass() {
return JdbcAggregateTemplateIntegrationTests.class;
TestClass testClass() {
return TestClass.of(JdbcAggregateTemplateIntegrationTests.class);
}

@Bean
Expand All @@ -1938,9 +1948,11 @@ JdbcAggregateOperations operations(ApplicationEventPublisher publisher, Relation
}
}

@ContextConfiguration(classes = Config.class)
static class JdbcAggregateTemplateIntegrationTests extends AbstractJdbcAggregateTemplateIntegrationTests {}

@ActiveProfiles(value = PROFILE_SINGLE_QUERY_LOADING)
@ContextConfiguration(classes = Config.class)
static class JdbcAggregateTemplateSingleQueryLoadingIntegrationTests
extends AbstractJdbcAggregateTemplateIntegrationTests {

Expand Down

0 comments on commit 8fa9e3e

Please sign in to comment.