diff --git a/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java b/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java index 325da40e8da..da6a59d57dc 100644 --- a/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java +++ b/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java @@ -15,7 +15,9 @@ */ package org.apache.ibatis.builder.annotation; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.lang.reflect.Modifier; import java.util.HashMap; import java.util.Map; @@ -109,16 +111,16 @@ private SqlSource createSqlSource(Object parameterObject) { int bindParameterCount = providerMethodParameterTypes.length - (providerContext == null ? 0 : 1); String sql; if (providerMethodParameterTypes.length == 0) { - sql = (String) providerMethod.invoke(providerType.newInstance()); + sql = invokeProviderMethod(); } else if (bindParameterCount == 0) { - sql = (String) providerMethod.invoke(providerType.newInstance(), providerContext); + sql = invokeProviderMethod(providerContext); } else if (bindParameterCount == 1 && (parameterObject == null || providerMethodParameterTypes[(providerContextIndex == null || providerContextIndex == 1) ? 0 : 1].isAssignableFrom(parameterObject.getClass()))) { - sql = (String) providerMethod.invoke(providerType.newInstance(), extractProviderMethodArguments(parameterObject)); + sql = invokeProviderMethod(extractProviderMethodArguments(parameterObject)); } else if (parameterObject instanceof Map) { @SuppressWarnings("unchecked") Map params = (Map) parameterObject; - sql = (String) providerMethod.invoke(providerType.newInstance(), extractProviderMethodArguments(params, providerMethodArgumentNames)); + sql = invokeProviderMethod(extractProviderMethodArguments(params, providerMethodArgumentNames)); } else { throw new BuilderException("Error invoking SqlProvider method (" + providerType.getName() + "." + providerMethod.getName() @@ -160,6 +162,14 @@ private Object[] extractProviderMethodArguments(Map params, Stri return args; } + private String invokeProviderMethod(Object... args) throws Exception { + Object targetObject = null; + if (!Modifier.isStatic(providerMethod.getModifiers())) { + targetObject = providerType.newInstance(); + } + return (String) providerMethod.invoke(targetObject, args); + } + private String replacePlaceholder(String sql) { return PropertyParser.parse(sql, configuration.getVariables()); } diff --git a/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java b/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java index f28434422b2..857d082c2b0 100644 --- a/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java +++ b/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; +import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.SelectProvider; import org.apache.ibatis.builder.BuilderException; import org.apache.ibatis.builder.annotation.ProviderContext; @@ -52,6 +53,7 @@ public static void setUp() throws Exception { Reader reader = Resources .getResourceAsReader("org/apache/ibatis/submitted/sqlprovider/mybatis-config.xml"); sqlSessionFactory = new SqlSessionFactoryBuilder().build(reader); + sqlSessionFactory.getConfiguration().addMapper(StaticMethodSqlProviderMapper.class); reader.close(); // populate in-memory database @@ -472,6 +474,66 @@ public void mapperMultipleParamAndProviderContext() { } } + @Test + public void staticMethodNoArgument() { + SqlSession sqlSession = sqlSessionFactory.openSession(); + try { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(1, mapper.noArgument()); + } finally { + sqlSession.close(); + } + } + + @Test + public void staticMethodOneArgument() { + SqlSession sqlSession = sqlSessionFactory.openSession(); + try { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10, mapper.oneArgument(10)); + } finally { + sqlSession.close(); + } + } + + @Test + public void staticMethodMultipleArgument() { + SqlSession sqlSession = sqlSessionFactory.openSession(); + try { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(2, mapper.multipleArgument(1, 1)); + } finally { + sqlSession.close(); + } + } + + @Test + public void staticMethodOnlyProviderContext() { + SqlSession sqlSession = sqlSessionFactory.openSession(); + try { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals("onlyProviderContext", mapper.onlyProviderContext()); + } finally { + sqlSession.close(); + } + } + + @Test + public void staticMethodOneArgumentAndProviderContext() { + SqlSession sqlSession = sqlSessionFactory.openSession(); + try { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals("oneArgumentAndProviderContext 100", mapper.oneArgumentAndProviderContext(100)); + } finally { + sqlSession.close(); + } + } + public interface ErrorMapper { @SelectProvider(type = ErrorSqlBuilder.class, method = "methodNotFound") void methodNotFound(); @@ -508,4 +570,48 @@ public String multipleProviderContext(ProviderContext providerContext1, Provider } } + public interface StaticMethodSqlProviderMapper { + @SelectProvider(type = SqlProvider.class, method = "noArgument") + int noArgument(); + + @SelectProvider(type = SqlProvider.class, method = "oneArgument") + int oneArgument(Integer value); + + @SelectProvider(type = SqlProvider.class, method = "multipleArgument") + int multipleArgument(@Param("value1") Integer value1, @Param("value2") Integer value2); + + @SelectProvider(type = SqlProvider.class, method = "onlyProviderContext") + String onlyProviderContext(); + + @SelectProvider(type = SqlProvider.class, method = "oneArgumentAndProviderContext") + String oneArgumentAndProviderContext(Integer value); + + class SqlProvider { + public static String noArgument() { + return "SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS"; + } + + public static String oneArgument(Integer value) { + return "SELECT " + value + " FROM INFORMATION_SCHEMA.SYSTEM_USERS"; + } + + public static String multipleArgument(@Param("value1") Integer value1, + @Param("value2") Integer value2) { + return "SELECT " + (value1 + value2) + " FROM INFORMATION_SCHEMA.SYSTEM_USERS"; + } + + public static String onlyProviderContext(ProviderContext context) { + return "SELECT '" + context.getMapperMethod().getName() + + "' FROM INFORMATION_SCHEMA.SYSTEM_USERS"; + } + + public static String oneArgumentAndProviderContext(Integer value, ProviderContext context) { + return "SELECT '" + context.getMapperMethod().getName() + " " + value + + "' FROM INFORMATION_SCHEMA.SYSTEM_USERS"; + } + + } + + } + }