Skip to content

Commit d8e4768

Browse files
committed
Allow using embedded Id's without Embedded annotation.
Also, refactor duplications, introduce QueryAssert. See #574 Original pull request #1957
1 parent 2b92832 commit d8e4768

File tree

17 files changed

+455
-207
lines changed

17 files changed

+455
-207
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/MappingJdbcConverter.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,14 @@ private static Function<AggregatePath, Object> getWrappedValueProvider(Function<
454454
AggregatePath aggregatePath) {
455455

456456
AggregatePath idDefiningParentPath = aggregatePath.getIdDefiningParentPath();
457+
458+
if (!idDefiningParentPath.hasIdProperty()) {
459+
return ap -> {
460+
throw new IllegalStateException(
461+
"AggregatePath %s does not define an identifier property".formatted(idDefiningParentPath));
462+
};
463+
}
464+
457465
RelationalPersistentProperty idProperty = idDefiningParentPath.getRequiredIdProperty();
458466
AggregatePath idPath = idProperty.isEntity() ? idDefiningParentPath.append(idProperty) : idDefiningParentPath;
459467

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/RowDocumentExtractorSupport.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
2727
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
2828
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
29+
import org.springframework.data.relational.core.mapping.RelationalPredicates;
2930
import org.springframework.data.relational.core.sql.SqlIdentifier;
3031
import org.springframework.data.relational.domain.RowDocument;
3132
import org.springframework.lang.Nullable;
@@ -235,7 +236,7 @@ private void readEntity(RS row, RowDocument document, AggregatePath basePath,
235236

236237
AggregatePath path = basePath.append(property);
237238

238-
if (property.isEntity() && !property.isEmbedded() && (property.isCollectionLike() || property.isQualified())) {
239+
if (RelationalPredicates.isRelation(property) && (property.isCollectionLike() || property.isQualified())) {
239240

240241
readerState.put(property, new ContainerSink<>(aggregateContext, property, path));
241242
continue;

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlContext.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,21 @@ Table getTable(AggregatePath path) {
5656
}
5757

5858
Column getColumn(AggregatePath path) {
59-
60-
AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
61-
return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
59+
return getAliasedColumn(path, path.getColumnInfo());
6260
}
6361

6462
/**
6563
* A token reverse column, used in selects to identify, if an entity is present or {@literal null}.
66-
*
64+
*
6765
* @param path must not be null.
6866
* @return a {@literal Column} that is part of the effective primary key for the given path.
6967
* @since 4.0
7068
*/
7169
Column getAnyReverseColumn(AggregatePath path) {
70+
return getAliasedColumn(path, path.getTableInfo().backReferenceColumnInfos().any());
71+
}
7272

73-
AggregatePath.ColumnInfo columnInfo = path.getTableInfo().backReferenceColumnInfos().any();
73+
private Column getAliasedColumn(AggregatePath path, AggregatePath.ColumnInfo columnInfo) {
7474
return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
7575
}
7676
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java

Lines changed: 73 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,16 @@
1515
*/
1616
package org.springframework.data.jdbc.core.convert;
1717

18-
import java.util.*;
18+
import java.util.ArrayList;
19+
import java.util.Collection;
20+
import java.util.Collections;
21+
import java.util.Comparator;
22+
import java.util.HashSet;
23+
import java.util.LinkedHashSet;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.Set;
27+
import java.util.TreeSet;
1928
import java.util.function.BiFunction;
2029
import java.util.function.Function;
2130
import java.util.function.Predicate;
@@ -38,7 +47,6 @@
3847
import org.springframework.data.relational.core.sql.*;
3948
import org.springframework.data.relational.core.sql.render.SqlRenderer;
4049
import org.springframework.data.util.Lazy;
41-
import org.springframework.data.util.Pair;
4250
import org.springframework.data.util.Predicates;
4351
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
4452
import org.springframework.lang.Nullable;
@@ -123,11 +131,11 @@ public class SqlGenerator {
123131
* @param table the table to base the select on
124132
* @param pathFilter a filter for excluding paths from the select. All paths for which the filter returns
125133
* {@literal true} will be skipped when determining columns to select.
126-
* @return A select structure suitable for constructing more specialized selects by adding conditions.
134+
* @return a select structure suitable for constructing more specialized selects by adding conditions.
127135
* @since 4.0
128136
*/
129137
public SelectBuilder.SelectWhere createSelectBuilder(Table table, Predicate<AggregatePath> pathFilter) {
130-
return createSelectBuilder(table, pathFilter, Collections.emptyList());
138+
return createSelectBuilder(table, pathFilter, Collections.emptyList(), Query.empty());
131139
}
132140

133141
/**
@@ -188,13 +196,7 @@ private Condition getSubselectCondition(AggregatePath path,
188196
AggregatePath.TableInfo parentPathTableInfo = parentPath.getTableInfo();
189197
Table subSelectTable = Table.create(parentPathTableInfo.qualifiedTableName());
190198

191-
Map<AggregatePath, Column> selectFilterColumns = new TreeMap<>();
192-
193-
// TODO: cannot we simply pass on the columnInfos?
194-
parentPathTableInfo.effectiveIdColumnInfos().forEach( //
195-
(ap, ci) -> //
196-
selectFilterColumns.put(ap, subSelectTable.column(ci.name())) //
197-
);
199+
Map<AggregatePath, Column> selectFilterColumns = parentPathTableInfo.effectiveIdColumnInfos().toMap(subSelectTable);
198200

199201
Condition innerCondition;
200202

@@ -609,29 +611,24 @@ private SelectBuilder.SelectWhere selectBuilder(Collection<SqlIdentifier> keyCol
609611
}
610612

611613
private SelectBuilder.SelectWhere selectBuilder(Collection<SqlIdentifier> keyColumns, Query query) {
612-
613-
return createSelectBuilder(getTable(), ap -> false, keyColumns);
614+
return createSelectBuilder(getTable(), ap -> false, keyColumns, query);
614615
}
615616

616617
private SelectBuilder.SelectWhere createSelectBuilder(Table table, Predicate<AggregatePath> pathFilter,
617-
Collection<SqlIdentifier> keyColumns) {
618+
Collection<SqlIdentifier> keyColumns, Query query) {
618619

619620
Projection projection = getProjection(pathFilter, keyColumns, query, table);
620621
SelectBuilder.SelectJoin baseSelect = StatementBuilder.select(projection.columns()).from(table);
621622

622-
return (SelectBuilder.SelectWhere) addJoins(baseSelect, joinTables);
623+
return (SelectBuilder.SelectWhere) addJoins(baseSelect, projection.joins());
623624
}
624625

625-
private static SelectBuilder.SelectJoin addJoins(SelectBuilder.SelectJoin baseSelect, List<Join> joinTables) {
626-
627-
for (Join join : projection.joins()) {
628-
629-
baseSelect = baseSelect.leftOuterJoin(join.joinTable).on(join.condition);
630-
}
631-
return baseSelect;
626+
private static SelectBuilder.SelectJoin addJoins(SelectBuilder.SelectJoin baseSelect, Joins joins) {
627+
return joins.reduce(baseSelect, (join, select) -> select.leftOuterJoin(join.joinTable).on(join.condition));
632628
}
633629

634-
private Projection getProjection(Predicate<AggregatePath> pathFilter, Collection<SqlIdentifier> keyColumns, Query query, Table table) {
630+
private Projection getProjection(Predicate<AggregatePath> pathFilter, Collection<SqlIdentifier> keyColumns,
631+
Query query, Table table) {
635632

636633
Set<Expression> columns = new LinkedHashSet<>();
637634
Set<Join> joins = new LinkedHashSet<>();
@@ -642,7 +639,7 @@ private Projection getProjection(Predicate<AggregatePath> pathFilter, Collection
642639
AggregatePath aggregatePath = mappingContext.getAggregatePath(
643640
mappingContext.getPersistentPropertyPath(columnName.getReference(), entity.getTypeInformation()));
644641

645-
includeColumnAndJoin(aggregatePath, joins, columns);
642+
includeColumnAndJoin(aggregatePath, pathFilter, joins, columns);
646643
} catch (InvalidPersistentPropertyPath e) {
647644
columns.add(Column.create(columnName, table));
648645
}
@@ -656,22 +653,40 @@ private Projection getProjection(Predicate<AggregatePath> pathFilter, Collection
656653
AggregatePath aggregatePath = mappingContext.getAggregatePath(path);
657654

658655
if (pathFilter.test(aggregatePath)) {
659-
continue;
660-
}
656+
continue;
657+
}
661658

662-
includeColumnAndJoin(aggregatePath, joins, columns);
659+
includeColumnAndJoin(aggregatePath, pathFilter, joins, columns);
663660
}
664661
}
665662

666663
for (SqlIdentifier keyColumn : keyColumns) {
667664
columns.add(table.column(keyColumn).as(keyColumn));
668665
}
669666

670-
return new Projection(columns, joins);
667+
return new Projection(columns, Joins.of(joins));
671668
}
672669

673-
private void includeColumnAndJoin(AggregatePath aggregatePath, Collection<Join> joins,
674-
Collection<Expression> columns) {
670+
private void includeColumnAndJoin(AggregatePath aggregatePath, Predicate<AggregatePath> pathFilter,
671+
Collection<Join> joins, Collection<Expression> columns) {
672+
673+
if (aggregatePath.isEmbedded()) {
674+
675+
RelationalPersistentEntity<?> entity = aggregatePath.getRequiredLeafEntity();
676+
677+
for (RelationalPersistentProperty property : entity) {
678+
679+
AggregatePath nested = aggregatePath.append(property);
680+
681+
if (pathFilter.test(nested)) {
682+
continue;
683+
}
684+
685+
includeColumnAndJoin(nested, pathFilter, joins, columns);
686+
}
687+
688+
return;
689+
}
675690

676691
joins.addAll(getJoins(aggregatePath));
677692

@@ -687,7 +702,24 @@ private void includeColumnAndJoin(AggregatePath aggregatePath, Collection<Join>
687702
* @param columns
688703
* @param joins
689704
*/
690-
record Projection(Set<Expression> columns, Set<Join> joins) {
705+
record Projection(Collection<Expression> columns, Joins joins) {
706+
707+
}
708+
709+
record Joins(Collection<Join> joins) {
710+
711+
public static Joins of(Collection<Join> joins) {
712+
return new Joins(joins);
713+
}
714+
715+
public <T> T reduce(T identity, BiFunction<Join, T, T> accumulator) {
716+
717+
T result = identity;
718+
for (Join join : joins) {
719+
result = accumulator.apply(join, result);
720+
}
721+
return result;
722+
}
691723
}
692724

693725
private SelectBuilder.SelectOrdered selectBuilder(Collection<SqlIdentifier> keyColumns, Sort sort,
@@ -922,11 +954,8 @@ private String createDeleteByPathAndCriteria(AggregatePath path,
922954
.from(table);
923955
Delete delete;
924956

925-
Map<AggregatePath, Column> columns = new TreeMap<>();
926957
AggregatePath.ColumnInfos columnInfos = path.getTableInfo().backReferenceColumnInfos();
927-
928-
// TODO: cannot we simply pass on the columnInfos?
929-
columnInfos.forEach((ag, ci) -> columns.put(ag, table.column(ci.name())));
958+
Map<AggregatePath, Column> columns = columnInfos.toMap(table);
930959

931960
if (isFirstNonRoot(path)) {
932961

@@ -978,22 +1007,19 @@ private Table getTable() {
9781007
* @return a single column of the primary key to be used in places where one need something not null to be selected.
9791008
*/
9801009
private Column getSingleNonNullColumn() {
981-
982-
// getColumn() is slightly different from the code in any(…). Why?
983-
// AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
984-
// return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
985-
986-
AggregatePath.ColumnInfos columnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
987-
return columnInfos.any((ap, ci) -> sqlContext.getColumn(ap));
1010+
return doGetColumn(AggregatePath.ColumnInfos::any);
9881011
}
9891012

9901013
private List<Column> getIdColumns() {
1014+
return doGetColumn(AggregatePath.ColumnInfos::toColumnList);
1015+
}
1016+
1017+
private <T> T doGetColumn(
1018+
BiFunction<AggregatePath.ColumnInfos, BiFunction<AggregatePath, AggregatePath.ColumnInfo, Column>, T> columnListFunction) {
9911019

9921020
AggregatePath.ColumnInfos columnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
9931021

994-
// sqlcontext.getColumn (vs sqlContext.getTable
995-
return columnInfos
996-
.toColumnList((aggregatePath, columnInfo) -> sqlContext.getColumn(aggregatePath));
1022+
return columnListFunction.apply(columnInfos, (aggregatePath, columnInfo) -> sqlContext.getColumn(aggregatePath));
9971023
}
9981024

9991025
private Column getVersionColumn() {
@@ -1164,7 +1190,7 @@ private SelectBuilder.SelectJoin getExistsSelect() {
11641190
}
11651191
}
11661192

1167-
return addJoins(baseSelect, joins);
1193+
return addJoins(baseSelect, Joins.of(joins));
11681194
}
11691195

11701196
/**
@@ -1199,7 +1225,7 @@ private SelectBuilder.SelectJoin getSelectCountWithExpression(Expression... coun
11991225
joins.add(join);
12001226
}
12011227
}
1202-
return addJoins(baseSelect, joins);
1228+
return addJoins(baseSelect, Joins.of(joins));
12031229
}
12041230

12051231
private SelectBuilder.SelectOrdered applyQueryOnSelect(Query query, MapSqlParameterSource parameterSource,

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
3434
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
3535
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
36+
import org.springframework.data.relational.core.mapping.RelationalPredicates;
3637
import org.springframework.data.relational.core.sql.SqlIdentifier;
3738
import org.springframework.lang.Nullable;
3839

@@ -87,7 +88,7 @@ <T> SqlIdentifierParameterSource forInsert(T instance, Class<T> domainType, Iden
8788
AggregatePath.ColumnInfos columnInfos = context.getAggregatePath(persistentEntity).getTableInfo().idColumnInfos();
8889

8990
// fullPath: because we use the result with a PropertyPathAccessor
90-
columnInfos.forEachLong((ap, __) -> {
91+
columnInfos.forEach((ap, __) -> {
9192
Object idValue = propertyPathAccessor.getProperty(ap.getRequiredPersistentPropertyPath());
9293
RelationalPersistentProperty idProperty = ap.getRequiredLeafProperty();
9394
addConvertedPropertyValue(parameterSource, idProperty, idValue, idProperty.getColumnName());
@@ -259,12 +260,14 @@ private <S, T> SqlIdentifierParameterSource getParameterSource(@Nullable S insta
259260
PersistentPropertyAccessor<S> propertyAccessor = instance != null ? persistentEntity.getPropertyAccessor(instance)
260261
: NoValuePropertyAccessor.instance();
261262

263+
262264
persistentEntity.doWithAll(property -> {
263265

264266
if (skipProperty.test(property) || !property.isWritable()) {
265267
return;
266268
}
267-
if (property.isEntity() && !property.isEmbedded()) {
269+
270+
if (RelationalPredicates.isRelation(property)) {
268271
return;
269272
}
270273

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/mapping/schema/Tables.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
3434
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
3535
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
36+
import org.springframework.data.relational.core.mapping.RelationalPredicates;
3637
import org.springframework.data.relational.core.sql.SqlIdentifier;
3738
import org.springframework.lang.Nullable;
3839
import org.springframework.util.Assert;
@@ -68,7 +69,7 @@ public static Tables from(Stream<? extends RelationalPersistentEntity<?>> persis
6869

6970
for (RelationalPersistentProperty property : entity) {
7071

71-
if (property.isEntity() && !property.isEmbedded()) {
72+
if (RelationalPredicates.isRelation(property)) {
7273
foreignKeyMetadataList.add(createForeignKeyMetadata(entity, property, context, sqlTypeMapping));
7374
continue;
7475
}

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/CompositeIdAggregateTemplateHsqlIntegrationTests.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ void sortByCompositeIdParts() {
237237
@Test // GH-574
238238
void projectByCompositeIdParts() {
239239

240-
SimpleEntityWithEmbeddedPk alpha = template.insert( //
240+
template.insert( //
241241
new SimpleEntityWithEmbeddedPk( //
242242
new EmbeddedPk(23L, "x"), "alpha" //
243243
));
@@ -246,16 +246,19 @@ void projectByCompositeIdParts() {
246246
SimpleEntityWithEmbeddedPk projected = template.findOne(projectingQuery, SimpleEntityWithEmbeddedPk.class)
247247
.orElseThrow();
248248

249-
// Projection still does a full select, otherwise one would be null.
250-
// See https://github.com/spring-projects/spring-data-relational/issues/1821
249+
assertThat(projected).isEqualTo(new SimpleEntityWithEmbeddedPk(new EmbeddedPk(null, "x"), "alpha"));
250+
251+
projectingQuery = Query.empty().columns("embeddedPk", "name");
252+
projected = template.findOne(projectingQuery, SimpleEntityWithEmbeddedPk.class).orElseThrow();
253+
251254
assertThat(projected).isEqualTo(new SimpleEntityWithEmbeddedPk(new EmbeddedPk(23L, "x"), "alpha"));
252255
}
253256

254257
private record WrappedPk(Long id) {
255258
}
256259

257260
private record SimpleEntity( //
258-
@Id @Embedded(onEmpty = Embedded.OnEmpty.USE_NULL) WrappedPk wrappedPk, //
261+
@Id WrappedPk wrappedPk, //
259262
String name //
260263
) {
261264
}
@@ -272,7 +275,7 @@ private record EmbeddedPk(Long one, String two) {
272275
}
273276

274277
private record SimpleEntityWithEmbeddedPk( //
275-
@Id @Embedded(onEmpty = Embedded.OnEmpty.USE_NULL) EmbeddedPk embeddedPk, //
278+
@Id EmbeddedPk embeddedPk, //
276279
String name //
277280
) {
278281
}

0 commit comments

Comments
 (0)