diff --git a/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java b/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java index 6a1ecab9508..b3a3eae8185 100644 --- a/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java +++ b/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java @@ -15,12 +15,6 @@ */ package org.apache.ibatis.builder.annotation; -import java.lang.annotation.Annotation; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.Map; - import org.apache.ibatis.annotations.Lang; import org.apache.ibatis.builder.BuilderException; import org.apache.ibatis.mapping.BoundSql; @@ -28,6 +22,12 @@ import org.apache.ibatis.reflection.ParamNameResolver; import org.apache.ibatis.scripting.LanguageDriver; import org.apache.ibatis.session.Configuration; +import java.lang.annotation.Annotation; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.HashMap; +import java.util.Map; /** * @author Clinton Begin @@ -44,6 +44,28 @@ public class ProviderSqlSource implements SqlSource { private ProviderContext providerContext; private Integer providerContextIndex; + private static final Map, Class> primitiveWrapperMap = new HashMap<>(); + static { + primitiveWrapperMap.put(byte.class, Byte.class); + primitiveWrapperMap.put(short.class, Short.class); + primitiveWrapperMap.put(int.class, Integer.class); + primitiveWrapperMap.put(long.class, Long.class); + primitiveWrapperMap.put(float.class, Float.class); + primitiveWrapperMap.put(double.class, Double.class); + primitiveWrapperMap.put(boolean.class, Boolean.class); + primitiveWrapperMap.put(char.class, Character.class); + } + + private static boolean isAssignableFrom(Class to, Class from) { + if (to.isAssignableFrom(from)) { + return true; + } + if (to.isPrimitive()) { + return from == primitiveWrapperMap.get(to); + } + return false; + } + /** * @deprecated Please use the {@link #ProviderSqlSource(Configuration, Object, Class, Method)} instead of this. */ @@ -121,7 +143,7 @@ private SqlSource createSqlSource(Object parameterObject) { } else if (bindParameterCount == 0) { sql = invokeProviderMethod(providerContext); } else if (bindParameterCount == 1 - && (parameterObject == null || providerMethodParameterTypes[providerContextIndex == null || providerContextIndex == 1 ? 0 : 1].isAssignableFrom(parameterObject.getClass()))) { + && (parameterObject == null || isAssignableFrom(providerMethodParameterTypes[providerContextIndex == null || providerContextIndex == 1 ? 0 : 1], parameterObject.getClass()))) { sql = invokeProviderMethod(extractProviderMethodArguments(parameterObject)); } else if (parameterObject instanceof Map) { @SuppressWarnings("unchecked") diff --git a/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java b/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java index 0a2720a86d3..3d3862f861e 100644 --- a/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java +++ b/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java @@ -464,6 +464,78 @@ void staticMethodOneArgument() { } } + @Test + void staticMethodOnePrimitiveByteArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals((byte) 10, mapper.onePrimitiveByteArgument((byte) 10)); + } + } + + @Test + void staticMethodOnePrimitiveShortArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals((short) 10, mapper.onePrimitiveShortArgument((short) 10)); + } + } + + @Test + void staticMethodOnePrimitiveIntArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10, mapper.onePrimitiveIntArgument(10)); + } + } + + @Test + void staticMethodOnePrimitiveLongArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10L, mapper.onePrimitiveLongArgument(10L)); + } + } + + @Test + void staticMethodOnePrimitiveFloatArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10.1F, mapper.onePrimitiveFloatArgument(10.1F)); + } + } + + @Test + void staticMethodOnePrimitiveDoubleArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10.1D, mapper.onePrimitiveDoubleArgument(10.1D)); + } + } + + @Test + void staticMethodOnePrimitiveBooleanArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertTrue(mapper.onePrimitiveBooleanArgument(true)); + } + } + + @Test + void staticMethodOnePrimitiveCharArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals('A', mapper.onePrimitiveCharArgument('A')); + } + } + @Test void staticMethodMultipleArgument() { try (SqlSession sqlSession = sqlSessionFactory.openSession()) { @@ -541,6 +613,30 @@ public interface StaticMethodSqlProviderMapper { @SelectProvider(type = SqlProvider.class, method = "oneArgument") int oneArgument(Integer value); + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveByteArgument") + byte onePrimitiveByteArgument(byte value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveShortArgument") + short onePrimitiveShortArgument(short value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveIntArgument") + int onePrimitiveIntArgument(int value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveLongArgument") + long onePrimitiveLongArgument(long value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveFloatArgument") + float onePrimitiveFloatArgument(float value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveDoubleArgument") + double onePrimitiveDoubleArgument(double value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveBooleanArgument") + boolean onePrimitiveBooleanArgument(boolean value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveCharArgument") + char onePrimitiveCharArgument(char value); + @SelectProvider(type = SqlProvider.class, method = "multipleArgument") int multipleArgument(@Param("value1") Integer value1, @Param("value2") Integer value2); @@ -562,6 +658,46 @@ public static StringBuilder oneArgument(Integer value) { .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); } + public static StringBuilder onePrimitiveByteArgument(byte value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveShortArgument(short value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveIntArgument(int value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveLongArgument(long value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveFloatArgument(float value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveDoubleArgument(double value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveBooleanArgument(boolean value) { + return new StringBuilder().append("SELECT ").append(value ? 1 : 0) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveCharArgument(char value) { + return new StringBuilder().append("SELECT '").append(value) + .append("' FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + public static CharSequence multipleArgument(@Param("value1") Integer value1, @Param("value2") Integer value2) { return "SELECT " + (value1 + value2) + " FROM INFORMATION_SCHEMA.SYSTEM_USERS";