diff --git a/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java b/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java index 6a1ecab9508..0b3b1fa9744 100644 --- a/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java +++ b/src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java @@ -38,6 +38,7 @@ public class ProviderSqlSource implements SqlSource { private final Configuration configuration; private final Class providerType; private final LanguageDriver languageDriver; + private final Method mapperMethod; private Method providerMethod; private String[] providerMethodArgumentNames; private Class[] providerMethodParameterTypes; @@ -59,6 +60,7 @@ public ProviderSqlSource(Configuration configuration, Object provider, Class String providerMethodName; try { this.configuration = configuration; + this.mapperMethod = mapperMethod; Lang lang = mapperMethod == null ? null : mapperMethod.getAnnotation(Lang.class); this.languageDriver = configuration.getLanguageDriver(lang == null ? null : lang.value()); this.providerType = getProviderType(provider, mapperMethod); @@ -116,33 +118,45 @@ private SqlSource createSqlSource(Object parameterObject) { try { int bindParameterCount = providerMethodParameterTypes.length - (providerContext == null ? 0 : 1); String sql; - if (providerMethodParameterTypes.length == 0) { + if (parameterObject instanceof Map) { + if (bindParameterCount == 1 && providerMethodParameterTypes[0] == Map.class) { + sql = invokeProviderMethod(extractProviderMethodArguments(parameterObject)); + } else { + @SuppressWarnings("unchecked") + Map params = (Map) parameterObject; + sql = invokeProviderMethod(extractProviderMethodArguments(params, providerMethodArgumentNames)); + } + } else if (providerMethodParameterTypes.length == 0) { sql = invokeProviderMethod(); - } else if (bindParameterCount == 0) { - sql = invokeProviderMethod(providerContext); - } else if (bindParameterCount == 1 - && (parameterObject == null || providerMethodParameterTypes[providerContextIndex == null || providerContextIndex == 1 ? 0 : 1].isAssignableFrom(parameterObject.getClass()))) { + } else if (providerMethodParameterTypes.length == 1) { + if (providerContext == null) { + sql = invokeProviderMethod(parameterObject); + } else { + sql = invokeProviderMethod(providerContext); + } + } else if (providerMethodParameterTypes.length == 2) { sql = invokeProviderMethod(extractProviderMethodArguments(parameterObject)); - } else if (parameterObject instanceof Map) { - @SuppressWarnings("unchecked") - Map params = (Map) parameterObject; - sql = invokeProviderMethod(extractProviderMethodArguments(params, providerMethodArgumentNames)); } else { - throw new BuilderException("Error invoking SqlProvider method (" - + providerType.getName() + "." + providerMethod.getName() - + "). Cannot invoke a method that holds " - + (bindParameterCount == 1 ? "named argument(@Param)" : "multiple arguments") - + " using a specifying parameterObject. In this case, please specify a 'java.util.Map' object."); + throw new BuilderException("Cannot invoke SqlProvider method '" + providerMethod + + "' with specify parameter '" + (parameterObject == null ? null : parameterObject.getClass()) + + "' because SqlProvider method arguments for '" + mapperMethod + "' is an invalid combination."); } Class parameterType = parameterObject == null ? Object.class : parameterObject.getClass(); return languageDriver.createSqlSource(configuration, sql, parameterType); } catch (BuilderException e) { throw e; } catch (Exception e) { - throw new BuilderException("Error invoking SqlProvider method (" - + providerType.getName() + "." + providerMethod.getName() - + "). Cause: " + e, e); + throw new BuilderException("Error invoking SqlProvider method '" + providerMethod + + "' with specify parameter '" + (parameterObject == null ? null : parameterObject.getClass()) + "'. Cause: " + extractRootCause(e), e); + } + } + + private Throwable extractRootCause(Exception e) { + Throwable cause = e; + while(cause.getCause() != null) { + cause = e.getCause(); } + return cause; } private Object[] extractProviderMethodArguments(Object parameterObject) { diff --git a/src/test/java/org/apache/ibatis/submitted/sqlprovider/Mapper.java b/src/test/java/org/apache/ibatis/submitted/sqlprovider/Mapper.java index c72a5ab4fe6..501142b0a43 100644 --- a/src/test/java/org/apache/ibatis/submitted/sqlprovider/Mapper.java +++ b/src/test/java/org/apache/ibatis/submitted/sqlprovider/Mapper.java @@ -1,5 +1,5 @@ /** - * Copyright 2009-2018 the original author or authors. + * Copyright 2009-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,6 +41,9 @@ public interface Mapper extends BaseMapper { @SelectProvider(type = OurSqlBuilder.class, method = "buildGetUsersByCriteriaMapQuery") List getUsersByCriteriaMap(Map criteria); + @SelectProvider(type = OurSqlBuilder.class, method = "buildGetUsersByCriteriaMapWithParamQuery") + List getUsersByCriteriaMapWithParam(Map criteria); + @SelectProvider(type = OurSqlBuilder.class, method = "buildGetUsersByNameQuery") List getUsersByName(String name, String orderByColumn); diff --git a/src/test/java/org/apache/ibatis/submitted/sqlprovider/OurSqlBuilder.java b/src/test/java/org/apache/ibatis/submitted/sqlprovider/OurSqlBuilder.java index 6f31de2294c..edda4f028bb 100644 --- a/src/test/java/org/apache/ibatis/submitted/sqlprovider/OurSqlBuilder.java +++ b/src/test/java/org/apache/ibatis/submitted/sqlprovider/OurSqlBuilder.java @@ -84,6 +84,19 @@ public String buildGetUsersByCriteriaMapQuery(final Map criteria }}.toString(); } + public String buildGetUsersByCriteriaMapWithParamQuery(@Param("id") Integer id, @Param("name") String name) { + return new SQL() {{ + SELECT("*"); + FROM("users"); + if (id != null) { + WHERE("id = #{id}"); + } + if (name != null) { + WHERE("name like #{name} || '%'"); + } + }}.toString(); + } + public String buildGetUsersByNameQuery(final String name, final String orderByColumn) { return new SQL(){{ SELECT("*"); diff --git a/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java b/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java index 0a2720a86d3..74e568df14c 100644 --- a/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java +++ b/src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java @@ -165,6 +165,30 @@ void shouldGetUsersByCriteriaMap() { } } + @Test + void shouldGetUsersByCriteriaMapWithParam() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + Mapper mapper = sqlSession.getMapper(Mapper.class); + { + Map criteria = new HashMap<>(); + criteria.put("id", 1); + List users = mapper.getUsersByCriteriaMapWithParam(criteria); + assertEquals(1, users.size()); + assertEquals("User1", users.get(0).getName()); + } + { + Map criteria = new HashMap<>(); + criteria.put("name", "User"); + List users = mapper.getUsersByCriteriaMapWithParam(criteria); + assertEquals(4, users.size()); + assertEquals("User1", users.get(0).getName()); + assertEquals("User2", users.get(1).getName()); + assertEquals("User3", users.get(2).getName()); + assertEquals("User4", users.get(3).getName()); + } + } + } + // Test for multiple parameter without @Param @Test void shouldGetUsersByName() { @@ -330,7 +354,7 @@ void notSupportParameterObjectOnMultipleArguments() throws NoSuchMethodException .getBoundSql(new Object()); fail(); } catch (BuilderException e) { - assertTrue(e.getMessage().contains("Error invoking SqlProvider method (org.apache.ibatis.submitted.sqlprovider.OurSqlBuilder.buildGetUsersByNameQuery). Cannot invoke a method that holds multiple arguments using a specifying parameterObject. In this case, please specify a 'java.util.Map' object.")); + assertTrue(e.getMessage().contains("Error invoking SqlProvider method 'public java.lang.String org.apache.ibatis.submitted.sqlprovider.OurSqlBuilder.buildGetUsersByNameQuery(java.lang.String,java.lang.String)' with specify parameter 'class java.lang.Object'. Cause: java.lang.IllegalArgumentException: wrong number of arguments")); } } @@ -344,7 +368,7 @@ void notSupportParameterObjectOnNamedArgument() throws NoSuchMethodException { .getBoundSql(new Object()); fail(); } catch (BuilderException e) { - assertTrue(e.getMessage().contains("Error invoking SqlProvider method (org.apache.ibatis.submitted.sqlprovider.OurSqlBuilder.buildGetUsersByNameWithParamNameQuery). Cannot invoke a method that holds named argument(@Param) using a specifying parameterObject. In this case, please specify a 'java.util.Map' object.")); + assertTrue(e.getMessage().contains("Error invoking SqlProvider method 'public java.lang.String org.apache.ibatis.submitted.sqlprovider.OurSqlBuilder.buildGetUsersByNameWithParamNameQuery(java.lang.String)' with specify parameter 'class java.lang.Object'. Cause: java.lang.IllegalArgumentException: argument type mismatch")); } } @@ -358,7 +382,21 @@ void invokeError() throws NoSuchMethodException { .getBoundSql(new Object()); fail(); } catch (BuilderException e) { - assertTrue(e.getMessage().contains("Error invoking SqlProvider method (org.apache.ibatis.submitted.sqlprovider.SqlProviderTest$ErrorSqlBuilder.invokeError). Cause: java.lang.reflect.InvocationTargetException")); + assertTrue(e.getMessage().contains("Error invoking SqlProvider method 'public java.lang.String org.apache.ibatis.submitted.sqlprovider.SqlProviderTest$ErrorSqlBuilder.invokeError()' with specify parameter 'class java.lang.Object'. Cause: java.lang.UnsupportedOperationException: invokeError")); + } + } + + @Test + void invalidArgumentsCombination() throws NoSuchMethodException { + try { + Class mapperType = ErrorMapper.class; + Method mapperMethod = mapperType.getMethod("invalidArgumentsCombination", String.class); + new ProviderSqlSource(new Configuration(), + mapperMethod.getAnnotation(DeleteProvider.class), mapperType, mapperMethod) + .getBoundSql("foo"); + fail(); + } catch (BuilderException e) { + assertTrue(e.getMessage().contains("Cannot invoke SqlProvider method 'public java.lang.String org.apache.ibatis.submitted.sqlprovider.SqlProviderTest$ErrorSqlBuilder.invalidArgumentsCombination(org.apache.ibatis.builder.annotation.ProviderContext,java.lang.String,java.lang.String)' with specify parameter 'class java.lang.String' because SqlProvider method arguments for 'public abstract void org.apache.ibatis.submitted.sqlprovider.SqlProviderTest$ErrorMapper.invalidArgumentsCombination(java.lang.String)' is an invalid combination.")); } } @@ -464,6 +502,96 @@ void staticMethodOneArgument() { } } + @Test + void staticMethodOnePrimitiveByteArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals((byte) 10, mapper.onePrimitiveByteArgument((byte) 10)); + } + } + + @Test + void staticMethodOnePrimitiveShortArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals((short) 10, mapper.onePrimitiveShortArgument((short) 10)); + } + } + + @Test + void staticMethodOnePrimitiveIntArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10, mapper.onePrimitiveIntArgument(10)); + } + } + + @Test + void staticMethodOnePrimitiveLongArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10L, mapper.onePrimitiveLongArgument(10L)); + } + } + + @Test + void staticMethodOnePrimitiveFloatArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10.1F, mapper.onePrimitiveFloatArgument(10.1F)); + } + } + + @Test + void staticMethodOnePrimitiveDoubleArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10.1D, mapper.onePrimitiveDoubleArgument(10.1D)); + } + } + + @Test + void staticMethodOnePrimitiveBooleanArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertTrue(mapper.onePrimitiveBooleanArgument(true)); + } + } + + @Test + void staticMethodOnePrimitiveCharArgument() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals('A', mapper.onePrimitiveCharArgument('A')); + } + } + + @Test + void boxing() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(10, mapper.boxing(10)); + } + } + + @Test + void unboxing() { + try (SqlSession sqlSession = sqlSessionFactory.openSession()) { + StaticMethodSqlProviderMapper mapper = + sqlSession.getMapper(StaticMethodSqlProviderMapper.class); + assertEquals(100, mapper.unboxing(100)); + } + } + @Test void staticMethodMultipleArgument() { try (SqlSession sqlSession = sqlSessionFactory.openSession()) { @@ -509,6 +637,10 @@ public interface ErrorMapper { @DeleteProvider(value = String.class, type = Integer.class) void differentTypeAndValue(); + + @DeleteProvider(type = ErrorSqlBuilder.class, method = "invalidArgumentsCombination") + void invalidArgumentsCombination(String value); + } @SuppressWarnings("unused") @@ -532,6 +664,10 @@ public String invokeError() { public String multipleProviderContext(ProviderContext providerContext1, ProviderContext providerContext2) { throw new UnsupportedOperationException("multipleProviderContext"); } + + public String invalidArgumentsCombination(ProviderContext providerContext, String value, String unnecessaryArgument) { + return ""; + } } public interface StaticMethodSqlProviderMapper { @@ -541,6 +677,36 @@ public interface StaticMethodSqlProviderMapper { @SelectProvider(type = SqlProvider.class, method = "oneArgument") int oneArgument(Integer value); + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveByteArgument") + byte onePrimitiveByteArgument(byte value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveShortArgument") + short onePrimitiveShortArgument(short value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveIntArgument") + int onePrimitiveIntArgument(int value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveLongArgument") + long onePrimitiveLongArgument(long value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveFloatArgument") + float onePrimitiveFloatArgument(float value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveDoubleArgument") + double onePrimitiveDoubleArgument(double value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveBooleanArgument") + boolean onePrimitiveBooleanArgument(boolean value); + + @SelectProvider(type = SqlProvider.class, method = "onePrimitiveCharArgument") + char onePrimitiveCharArgument(char value); + + @SelectProvider(type = SqlProvider.class, method = "boxing") + int boxing(int value); + + @SelectProvider(type = SqlProvider.class, method = "unboxing") + int unboxing(Integer value); + @SelectProvider(type = SqlProvider.class, method = "multipleArgument") int multipleArgument(@Param("value1") Integer value1, @Param("value2") Integer value2); @@ -562,6 +728,56 @@ public static StringBuilder oneArgument(Integer value) { .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); } + public static StringBuilder onePrimitiveByteArgument(byte value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveShortArgument(short value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveIntArgument(int value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveLongArgument(long value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveFloatArgument(float value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveDoubleArgument(double value) { + return new StringBuilder().append("SELECT ").append(value) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveBooleanArgument(boolean value) { + return new StringBuilder().append("SELECT ").append(value ? 1 : 0) + .append(" FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder onePrimitiveCharArgument(char value) { + return new StringBuilder().append("SELECT '").append(value) + .append("' FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder boxing(Integer value) { + return new StringBuilder().append("SELECT '").append(value) + .append("' FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + + public static StringBuilder unboxing(int value) { + return new StringBuilder().append("SELECT '").append(value) + .append("' FROM INFORMATION_SCHEMA.SYSTEM_USERS"); + } + public static CharSequence multipleArgument(@Param("value1") Integer value1, @Param("value2") Integer value2) { return "SELECT " + (value1 + value2) + " FROM INFORMATION_SCHEMA.SYSTEM_USERS";