21
21
import reactor .core .publisher .Flux ;
22
22
import reactor .core .publisher .Mono ;
23
23
24
- import java .beans .FeatureDescriptor ;
25
24
import java .util .Collections ;
25
+ import java .util .LinkedHashSet ;
26
26
import java .util .List ;
27
27
import java .util .Map ;
28
28
import java .util .Optional ;
29
+ import java .util .Set ;
29
30
import java .util .function .BiFunction ;
30
31
import java .util .function .Function ;
31
32
import java .util .stream .Collectors ;
32
33
33
34
import org .reactivestreams .Publisher ;
35
+
34
36
import org .springframework .beans .BeansException ;
35
37
import org .springframework .beans .factory .BeanFactory ;
36
38
import org .springframework .beans .factory .BeanFactoryAware ;
46
48
import org .springframework .data .mapping .callback .ReactiveEntityCallbacks ;
47
49
import org .springframework .data .mapping .context .MappingContext ;
48
50
import org .springframework .data .projection .EntityProjection ;
49
- import org .springframework .data .projection .ProjectionInformation ;
50
51
import org .springframework .data .projection .SpelAwareProxyProjectionFactory ;
51
52
import org .springframework .data .r2dbc .convert .R2dbcConverter ;
52
53
import org .springframework .data .r2dbc .dialect .DialectResolver ;
56
57
import org .springframework .data .r2dbc .mapping .event .AfterSaveCallback ;
57
58
import org .springframework .data .r2dbc .mapping .event .BeforeConvertCallback ;
58
59
import org .springframework .data .r2dbc .mapping .event .BeforeSaveCallback ;
60
+ import org .springframework .data .relational .core .mapping .PersistentPropertyTranslator ;
59
61
import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
60
62
import org .springframework .data .relational .core .mapping .RelationalPersistentProperty ;
61
63
import org .springframework .data .relational .core .query .Criteria ;
68
70
import org .springframework .data .relational .core .sql .SqlIdentifier ;
69
71
import org .springframework .data .relational .core .sql .Table ;
70
72
import org .springframework .data .relational .domain .RowDocument ;
73
+ import org .springframework .data .util .Predicates ;
71
74
import org .springframework .data .util .ProxyUtils ;
72
75
import org .springframework .lang .Nullable ;
73
76
import org .springframework .r2dbc .core .DatabaseClient ;
@@ -332,7 +335,7 @@ private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdent
332
335
333
336
StatementMapper .SelectSpec selectSpec = statementMapper //
334
337
.createSelect (tableName ) //
335
- .doWithTable ((table , spec ) -> spec .withProjection (getSelectProjection (table , query , returnType )));
338
+ .doWithTable ((table , spec ) -> spec .withProjection (getSelectProjection (table , query , entityType , returnType )));
336
339
337
340
if (query .getLimit () > 0 ) {
338
341
selectSpec = selectSpec .limit (query .getLimit ());
@@ -423,7 +426,8 @@ public <T> RowsFetchSpec<T> query(PreparedOperation<?> operation, Class<T> entit
423
426
}
424
427
425
428
@ Override
426
- public <T > RowsFetchSpec <T > query (PreparedOperation <?> operation , Class <?> entityClass , Class <T > resultType ) throws DataAccessException {
429
+ public <T > RowsFetchSpec <T > query (PreparedOperation <?> operation , Class <?> entityClass , Class <T > resultType )
430
+ throws DataAccessException {
427
431
428
432
Assert .notNull (operation , "PreparedOperation must not be null" );
429
433
Assert .notNull (entityClass , "Entity class must not be null" );
@@ -759,18 +763,16 @@ private <T> RelationalPersistentEntity<T> getRequiredEntity(T entity) {
759
763
return (RelationalPersistentEntity ) getRequiredEntity (entityType );
760
764
}
761
765
762
- private <T > List <Expression > getSelectProjection (Table table , Query query , Class <T > returnType ) {
766
+ private <T > List <Expression > getSelectProjection (Table table , Query query , Class <?> entityType , Class < T > returnType ) {
763
767
764
768
if (query .getColumns ().isEmpty ()) {
765
769
766
- if (returnType .isInterface ()) {
770
+ EntityProjection <T , ?> projection = converter .introspectProjection (returnType , entityType );
771
+
772
+ if (projection .isProjection () && projection .isClosedProjection ()) {
767
773
768
- ProjectionInformation projectionInformation = projectionFactory . getProjectionInformation ( returnType );
774
+ return computeProjectedFields ( table , returnType , projection );
769
775
770
- if (projectionInformation .isClosed ()) {
771
- return projectionInformation .getInputProperties ().stream ().map (FeatureDescriptor ::getName ).map (table ::column )
772
- .collect (Collectors .toList ());
773
- }
774
776
}
775
777
776
778
return Collections .singletonList (table .asterisk ());
@@ -779,6 +781,36 @@ private <T> List<Expression> getSelectProjection(Table table, Query query, Class
779
781
return query .getColumns ().stream ().map (table ::column ).collect (Collectors .toList ());
780
782
}
781
783
784
+ @ SuppressWarnings ("unchecked" )
785
+ private <T > List <Expression > computeProjectedFields (Table table , Class <T > returnType ,
786
+ EntityProjection <T , ?> projection ) {
787
+
788
+ if (returnType .isInterface ()) {
789
+
790
+ Set <String > properties = new LinkedHashSet <>();
791
+ projection .forEach (it -> {
792
+ properties .add (it .getPropertyPath ().getSegment ());
793
+ });
794
+
795
+ return properties .stream ().map (table ::column ).collect (Collectors .toList ());
796
+ }
797
+
798
+ Set <SqlIdentifier > properties = new LinkedHashSet <>();
799
+ // DTO projections use merged metadata between domain type and result type
800
+ PersistentPropertyTranslator translator = PersistentPropertyTranslator .create (
801
+ mappingContext .getRequiredPersistentEntity (projection .getDomainType ()),
802
+ Predicates .negate (RelationalPersistentProperty ::hasExplicitColumnName ));
803
+
804
+ RelationalPersistentEntity <?> persistentEntity = mappingContext
805
+ .getRequiredPersistentEntity (projection .getMappedType ());
806
+ for (RelationalPersistentProperty property : persistentEntity ) {
807
+ properties .add (translator .translate (property ).getColumnName ());
808
+ }
809
+
810
+ return properties .stream ().map (table ::column ).collect (Collectors .toList ());
811
+ }
812
+
813
+ @ SuppressWarnings ("unchecked" )
782
814
public <T > RowsFetchSpec <T > getRowsFetchSpec (DatabaseClient .GenericExecuteSpec executeSpec , Class <?> entityType ,
783
815
Class <T > resultType ) {
784
816
@@ -791,13 +823,13 @@ public <T> RowsFetchSpec<T> getRowsFetchSpec(DatabaseClient.GenericExecuteSpec e
791
823
} else {
792
824
793
825
EntityProjection <T , ?> projection = converter .introspectProjection (resultType , entityType );
826
+ Class <T > typeToRead = projection .isProjection () ? resultType
827
+ : resultType .isInterface () ? (Class <T >) entityType : resultType ;
794
828
795
829
rowMapper = (row , rowMetadata ) -> {
796
830
797
- RowDocument document = dataAccessStrategy .toRowDocument (resultType , row , rowMetadata .getColumnMetadatas ());
798
-
799
- return projection .isProjection () ? converter .project (projection , document )
800
- : converter .read (resultType , document );
831
+ RowDocument document = dataAccessStrategy .toRowDocument (typeToRead , row , rowMetadata .getColumnMetadatas ());
832
+ return converter .project (projection , document );
801
833
};
802
834
}
803
835
0 commit comments