Skip to content

Commit 6908049

Browse files
committed
Fix broken backward compatibility for gh-1604
1 parent 0a9f940 commit 6908049

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

src/main/java/org/apache/ibatis/builder/annotation/ProviderSqlSource.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,11 @@ public BoundSql getBoundSql(Object parameterObject) {
116116

117117
private SqlSource createSqlSource(Object parameterObject) {
118118
try {
119-
int bindParameterCount = providerMethodParameterTypes.length - (providerContext == null ? 0 : 1);
120119
String sql;
121120
if (parameterObject instanceof Map) {
122-
if (bindParameterCount == 1 && providerMethodParameterTypes[0] == Map.class) {
123-
sql = invokeProviderMethod(extractProviderMethodArguments(parameterObject));
124-
} else {
125-
@SuppressWarnings("unchecked")
126-
Map<String, Object> params = (Map<String, Object>) parameterObject;
127-
sql = invokeProviderMethod(extractProviderMethodArguments(params, providerMethodArgumentNames));
128-
}
121+
@SuppressWarnings("unchecked")
122+
Map<String, Object> params = (Map<String, Object>) parameterObject;
123+
sql = invokeProviderMethod(extractProviderMethodArguments(params, providerMethodArgumentNames, providerMethodParameterTypes));
129124
} else if (providerMethodParameterTypes.length == 0) {
130125
sql = invokeProviderMethod();
131126
} else if (providerMethodParameterTypes.length == 1) {
@@ -170,11 +165,13 @@ private Object[] extractProviderMethodArguments(Object parameterObject) {
170165
}
171166
}
172167

173-
private Object[] extractProviderMethodArguments(Map<String, Object> params, String[] argumentNames) {
168+
private Object[] extractProviderMethodArguments(Map<String, Object> params, String[] argumentNames, Class<?>[] argumentTypes) {
174169
Object[] args = new Object[argumentNames.length];
175170
for (int i = 0; i < args.length; i++) {
176171
if (providerContextIndex != null && providerContextIndex == i) {
177172
args[i] = providerContext;
173+
} else if(argumentTypes[i].isAssignableFrom(params.getClass())) {
174+
args[i] = params;
178175
} else {
179176
args[i] = params.get(argumentNames[i]);
180177
}

src/test/java/org/apache/ibatis/submitted/sqlprovider/SqlProviderTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.ibatis.annotations.DeleteProvider;
3333
import org.apache.ibatis.annotations.Param;
3434
import org.apache.ibatis.annotations.SelectProvider;
35+
import org.apache.ibatis.binding.MapperMethod;
3536
import org.apache.ibatis.builder.BuilderException;
3637
import org.apache.ibatis.builder.annotation.ProviderContext;
3738
import org.apache.ibatis.builder.annotation.ProviderSqlSource;
@@ -619,6 +620,24 @@ void staticMethodOneArgumentAndProviderContext() {
619620
}
620621
}
621622

623+
@Test
624+
void mapAndProviderContext() {
625+
try (SqlSession sqlSession = sqlSessionFactory.openSession()) {
626+
StaticMethodSqlProviderMapper mapper =
627+
sqlSession.getMapper(StaticMethodSqlProviderMapper.class);
628+
assertEquals("mybatis", mapper.mapAndProviderContext("mybatis"));
629+
}
630+
}
631+
632+
@Test
633+
void providerContextAndMap() {
634+
try (SqlSession sqlSession = sqlSessionFactory.openSession()) {
635+
StaticMethodSqlProviderMapper mapper =
636+
sqlSession.getMapper(StaticMethodSqlProviderMapper.class);
637+
assertEquals("mybatis", mapper.providerContextAndParamMap("mybatis"));
638+
}
639+
}
640+
622641
public interface ErrorMapper {
623642
@SelectProvider(type = ErrorSqlBuilder.class, method = "methodNotFound")
624643
void methodNotFound();
@@ -716,6 +735,12 @@ public interface StaticMethodSqlProviderMapper {
716735
@SelectProvider(type = SqlProvider.class, method = "oneArgumentAndProviderContext")
717736
String oneArgumentAndProviderContext(Integer value);
718737

738+
@SelectProvider(type = SqlProvider.class, method = "mapAndProviderContext")
739+
String mapAndProviderContext(@Param("value") String value);
740+
741+
@SelectProvider(type = SqlProvider.class, method = "providerContextAndParamMap")
742+
String providerContextAndParamMap(@Param("value") String value);
743+
719744
@SuppressWarnings("unused")
720745
class SqlProvider {
721746

@@ -793,6 +818,14 @@ public static String oneArgumentAndProviderContext(Integer value, ProviderContex
793818
+ "' FROM INFORMATION_SCHEMA.SYSTEM_USERS";
794819
}
795820

821+
public static String mapAndProviderContext(Map<String, Object> map, ProviderContext context) {
822+
return "SELECT '" + map.get("value") + "' FROM INFORMATION_SCHEMA.SYSTEM_USERS";
823+
}
824+
825+
public static String providerContextAndParamMap(ProviderContext context, MapperMethod.ParamMap<Object> map) {
826+
return "SELECT '" + map.get("value") + "' FROM INFORMATION_SCHEMA.SYSTEM_USERS";
827+
}
828+
796829
}
797830

798831
}

0 commit comments

Comments
 (0)