34
34
import org .springframework .core .annotation .MergedAnnotation ;
35
35
import org .springframework .data .domain .SliceImpl ;
36
36
import org .springframework .data .domain .Sort ;
37
+ import org .springframework .data .javapoet .LordOfTheStrings ;
37
38
import org .springframework .data .jdbc .repository .query .JdbcQueryMethod ;
38
39
import org .springframework .data .jdbc .repository .query .Modifying ;
39
40
import org .springframework .data .jdbc .repository .query .ParameterBinding ;
44
45
import org .springframework .data .relational .core .sql .LockMode ;
45
46
import org .springframework .data .relational .repository .Lock ;
46
47
import org .springframework .data .repository .aot .generate .AotQueryMethodGenerationContext ;
48
+ import org .springframework .data .repository .aot .generate .MethodReturn ;
47
49
import org .springframework .data .repository .query .parser .Part ;
48
50
import org .springframework .data .support .PageableExecutionUtils ;
49
51
import org .springframework .data .util .Pair ;
50
- import org .springframework .data .util .ReflectionUtils ;
51
52
import org .springframework .javapoet .CodeBlock ;
52
53
import org .springframework .javapoet .CodeBlock .Builder ;
53
54
import org .springframework .javapoet .TypeName ;
58
59
import org .springframework .jdbc .core .namedparam .MapSqlParameterSource ;
59
60
import org .springframework .jdbc .core .namedparam .SqlParameterSource ;
60
61
import org .springframework .util .Assert ;
61
- import org .springframework .util .ObjectUtils ;
62
+ import org .springframework .util .ClassUtils ;
62
63
import org .springframework .util .StringUtils ;
63
64
64
65
/**
@@ -568,25 +569,26 @@ public CodeBlock build() {
568
569
569
570
Builder builder = CodeBlock .builder ();
570
571
571
- boolean isProjecting = !ObjectUtils .nullSafeEquals (
572
- TypeName .get (context .getRepositoryInformation ().getDomainType ()), context .getActualReturnType ());
573
- Type actualReturnType = isProjecting ? context .getActualReturnType ().getType ()
572
+ MethodReturn methodReturn = context .getMethodReturn ();
573
+ boolean isProjecting = methodReturn .isProjecting ()
574
+ || StringUtils .hasText (context .getDynamicProjectionParameterName ());
575
+ Type actualReturnType = isProjecting ? methodReturn .getActualReturnClass ()
574
576
: context .getRepositoryInformation ().getDomainType ();
575
- builder .add ("\n " );
576
577
577
- Class <?> returnType = context .getMethod ().getReturnType ();
578
- TypeName queryResultType = TypeName .get (context .getActualReturnType ().toClass ());
578
+ Class <?> returnType = context .getMethodReturn ().toClass ();
579
+
580
+ TypeName queryResultType = methodReturn .getActualClassName ();
579
581
String result = context .localVariable ("result" );
580
582
String rowMapper = context .localVariable ("rowMapper" );
581
583
582
584
if (modifying .isPresent ()) {
583
- return update (builder , returnType );
585
+ return update (returnType );
584
586
} else if (aotQuery .isCount ()) {
585
- return count (builder , result , returnType , queryResultType );
587
+ return count (result , returnType , queryResultType );
586
588
} else if (aotQuery .isExists ()) {
587
- return exists (builder , queryResultType );
589
+ return exists (queryResultType );
588
590
} else if (aotQuery .isDelete ()) {
589
- return delete (builder , rowMapper , result , queryResultType , returnType , actualReturnType );
591
+ return delete (rowMapper , result , queryResultType , returnType , actualReturnType );
590
592
} else {
591
593
592
594
String resultSetExtractor = null ;
@@ -603,7 +605,7 @@ public CodeBlock build() {
603
605
if (isProjecting ) {
604
606
typeToRead = context .getReturnedType ().getDomainType ();
605
607
} else {
606
- typeToRead = context . getActualReturnType (). getType ();
608
+ typeToRead = methodReturn . getActualReturnClass ();
607
609
}
608
610
609
611
builder .addStatement ("$T $L = getRowMapperFactory().create($T.class)" , RowMapper .class , rowMapper ,
@@ -667,63 +669,61 @@ public CodeBlock build() {
667
669
}
668
670
669
671
builder .addStatement ("return ($T) convertMany($L, %s)" .formatted (dynamicProjection ? "$L" : "$T.class" ),
670
- context . getReturnTypeName (), result , queryResultTypeRef );
672
+ methodReturn . getTypeName (), result , queryResultTypeRef );
671
673
} else if (queryMethod .isStreamQuery ()) {
672
674
673
675
builder .addStatement ("$1T $2L = getJdbcOperations().queryForStream($3L, $4L, $5L)" , Stream .class , result ,
674
676
queryVariableName , parameterSourceVariableName , rowMapper );
675
- builder .addStatement ("return ($T) convertMany($L, $T.class)" , context . getReturnTypeName (), result ,
677
+ builder .addStatement ("return ($T) convertMany($L, $T.class)" , methodReturn . getTypeName (), result ,
676
678
queryResultTypeRef );
677
679
} else {
678
680
679
681
builder .addStatement ("$T $L = queryForObject($L, $L, $L)" , Object .class , result , queryVariableName ,
680
682
parameterSourceVariableName , rowMapper );
681
683
682
- if (Optional . class . isAssignableFrom ( context . getReturnType (). toClass () )) {
684
+ if (methodReturn . isOptional ( )) {
683
685
builder .addStatement (
684
686
"return ($1T) $1T.ofNullable(convertOne($2L, %s))" .formatted (dynamicProjection ? "$3L" : "$3T.class" ),
685
687
Optional .class , result , queryResultTypeRef );
686
688
} else {
687
689
builder .addStatement ("return ($T) convertOne($L, %s)" .formatted (dynamicProjection ? "$L" : "$T.class" ),
688
- context . getReturnTypeName (), result , queryResultTypeRef );
690
+ methodReturn . getTypeName (), result , queryResultTypeRef );
689
691
}
690
692
}
691
693
}
692
694
693
695
return builder .build ();
694
696
}
695
697
696
- private CodeBlock update (Builder builder , Class <?> returnType ) {
698
+ private CodeBlock update (Class <?> returnType ) {
697
699
698
700
String result = context .localVariable ("result" );
699
701
700
- builder .add ("$[" );
701
-
702
- if (!ReflectionUtils .isVoid (returnType )) {
703
- builder .add ("int $L = " , result );
704
- }
702
+ Builder builder = CodeBlock .builder ();
705
703
706
- builder . add ("getJdbcOperations().update($L, $L)" , queryVariableName , parameterSourceVariableName );
707
- builder . add ( "; \n $]" );
704
+ LordOfTheStrings . InvocationBuilder invoke = LordOfTheStrings . invoke ("getJdbcOperations().update($L, $L)" ,
705
+ queryVariableName , parameterSourceVariableName );
708
706
709
- if (returnType == boolean .class || returnType == Boolean .class ) {
710
- builder .addStatement ("return $L != 0" , result );
711
- } else if (returnType == Long .class ) {
712
- builder .addStatement ("return (long) $L" , result );
713
- } else if (ReflectionUtils .isVoid (returnType )) {
714
- if (returnType == Void .class ) {
715
- builder .addStatement ("return null" );
716
- }
707
+ if (context .getMethodReturn ().isVoid ()) {
708
+ builder .addStatement (invoke .build ());
717
709
} else {
718
- builder .addStatement ("return $L" , result );
710
+ builder .addStatement (invoke . assignTo ( "int $L" , result ) );
719
711
}
720
712
713
+ builder .addStatement (LordOfTheStrings .returning (returnType ) //
714
+ .whenBoolean ("$L != 0" , result ) //
715
+ .whenBoxedLong ("(long) $L" , result ) //
716
+ .otherwise ("$L" , result )//
717
+ .build ());
718
+
721
719
return builder .build ();
722
720
}
723
721
724
- private CodeBlock delete (Builder builder , String rowMapper , String result , TypeName queryResultType ,
722
+ private CodeBlock delete (String rowMapper , String result , TypeName queryResultType ,
725
723
Class <?> returnType , Type actualReturnType ) {
726
724
725
+ CodeBlock .Builder builder = CodeBlock .builder ();
726
+
727
727
builder .addStatement ("$T $L = getRowMapperFactory().create($T.class)" , RowMapper .class , rowMapper ,
728
728
context .getRepositoryInformation ().getDomainType ());
729
729
@@ -732,48 +732,37 @@ private CodeBlock delete(Builder builder, String rowMapper, String result, TypeN
732
732
733
733
builder .addStatement ("$L.forEach(getOperations()::delete)" , result );
734
734
735
- if (Collection .class .isAssignableFrom (context .getReturnType ().toClass ())) {
736
- builder .addStatement ("return ($T) convertMany($L, $T.class)" , context .getReturnTypeName (), result ,
737
- queryResultType );
738
- } else if (returnType == context .getRepositoryInformation ().getDomainType ()) {
739
- builder .addStatement ("return ($1T) ($2L.isEmpty() ? null : $2L.iterator().next())" , actualReturnType , result );
740
- } else if (returnType == boolean .class || returnType == Boolean .class ) {
741
- builder .addStatement ("return !$L.isEmpty()" , result );
742
- } else if (returnType == Long .class ) {
743
- builder .addStatement ("return (long) $L.size()" , result );
744
- } else if (ReflectionUtils .isVoid (returnType )) {
745
- if (returnType == Void .class ) {
746
- builder .addStatement ("return null" );
747
- }
748
- } else {
749
- builder .addStatement ("return $L.size()" , result );
750
- }
735
+ builder .addStatement (LordOfTheStrings .returning (returnType ) //
736
+ .when (Collection .class .isAssignableFrom (context .getMethodReturn ().toClass ()),
737
+ "($T) convertMany($L, $T.class)" , context .getMethodReturn ().getTypeName (), result , queryResultType ) //
738
+ .when (context .getRepositoryInformation ().getDomainType (),
739
+ "($1T) ($2L.isEmpty() ? null : $2L.iterator().next())" , actualReturnType , result ) //
740
+ .whenBoolean ("!$L.isEmpty()" , result ) //
741
+ .whenBoxedLong ("(long) $L.size()" , result ) //
742
+ .otherwise ("$L.size()" , result ) //
743
+ .build ());
751
744
752
745
return builder .build ();
753
746
}
754
747
755
- private CodeBlock count (Builder builder , String result , Class <?> returnType , TypeName queryResultType ) {
748
+ private CodeBlock count (String result , Class <?> returnType , TypeName queryResultType ) {
749
+
750
+ CodeBlock .Builder builder = CodeBlock .builder ();
756
751
757
752
builder .addStatement ("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))" , Number .class , result ,
758
753
queryVariableName , parameterSourceVariableName , SingleColumnRowMapper .class );
759
754
760
- if (returnType == Long .class ) {
761
- builder .addStatement ("return $1L != null ? $1L.longValue() : null" , result );
762
- } else if (returnType == Integer .class ) {
763
- builder .addStatement ("return $1L != null ? $1L.intValue() : null" , result );
764
- } else if (returnType == Long .TYPE ) {
765
- builder .addStatement ("return $1L != null ? $1L.longValue() : 0L" , result );
766
- } else if (returnType == Integer .TYPE ) {
767
- builder .addStatement ("return $1L != null ? $1L.intValue() : 0" , result );
768
- } else {
769
- builder .addStatement ("return ($T) convertOne($L, $T.class)" , context .getReturnTypeName (), result ,
770
- queryResultType );
771
- }
755
+ builder .addStatement (LordOfTheStrings .returning (returnType ) //
756
+ .number (result ) //
757
+ .otherwise ("($T) convertOne($L, $T.class)" , context .getMethodReturn ().getTypeName (), result , queryResultType ) //
758
+ .build ());
772
759
773
760
return builder .build ();
774
761
}
775
762
776
- private CodeBlock exists (Builder builder , TypeName queryResultType ) {
763
+ private CodeBlock exists (TypeName queryResultType ) {
764
+
765
+ CodeBlock .Builder builder = CodeBlock .builder ();
777
766
778
767
builder .addStatement ("return ($T) getJdbcOperations().query($L, $L, $T::next)" , queryResultType ,
779
768
queryVariableName , parameterSourceVariableName , ResultSet .class );
@@ -783,8 +772,8 @@ private CodeBlock exists(Builder builder, TypeName queryResultType) {
783
772
784
773
public static boolean returnsModifying (Class <?> returnType ) {
785
774
786
- return returnType == int . class || returnType == long . class || returnType == Integer .class
787
- || returnType == Long .class ;
775
+ return ClassUtils . resolvePrimitiveIfNecessary ( returnType ) == Integer .class
776
+ || ClassUtils . resolvePrimitiveIfNecessary ( returnType ) == Long .class ;
788
777
}
789
778
790
779
}
0 commit comments