From b02f49a39b65e52cf117b71bc735ba0c39b08d84 Mon Sep 17 00:00:00 2001 From: Hccake Date: Tue, 9 Nov 2021 11:09:29 +0800 Subject: [PATCH 1/3] :bug: fix gitee I4FP6E, right join bug --- .../inner/TenantLineInnerInterceptor.java | 96 +++++++++++++++---- .../inner/TenantLineInnerInterceptorTest.java | 18 +++- 2 files changed, 92 insertions(+), 22 deletions(-) diff --git a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java index 863310f4e8..bb3fd1ce1f 100644 --- a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java +++ b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java @@ -41,6 +41,7 @@ import java.sql.Connection; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Collection; import java.util.Deque; import java.util.LinkedList; @@ -235,26 +236,37 @@ protected void appendSelectItem(List selectItems) { */ protected void processPlainSelect(PlainSelect plainSelect) { FromItem fromItem = plainSelect.getFromItem(); + + //#3087 github + List selectItems = plainSelect.getSelectItems(); + if (CollectionUtils.isNotEmpty(selectItems)) { + selectItems.forEach(this::processSelectItem); + } + + // #I4FP6E gitee:右连接查询时,where 条件需要过滤 + List rightJointTables; + List joins = plainSelect.getJoins(); + if (CollectionUtils.isNotEmpty(joins)) { + rightJointTables = processJoins(joins); + }else { + rightJointTables = new ArrayList<>(); + } + Expression where = plainSelect.getWhere(); processWhereSubSelect(where); if (fromItem instanceof Table) { Table fromTable = (Table) fromItem; - if (!tenantLineHandler.ignoreTable(fromTable.getName())) { + boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName()); + if (needIgnore) { + plainSelect.setWhere(builderExpression(where, null, rightJointTables)); + }else { //#1186 github - plainSelect.setWhere(builderExpression(where, fromTable)); + plainSelect.setWhere(builderExpression(where, fromTable, rightJointTables)); } } else { processFromItem(fromItem); } - //#3087 github - List selectItems = plainSelect.getSelectItems(); - if (CollectionUtils.isNotEmpty(selectItems)) { - selectItems.forEach(this::processSelectItem); - } - List joins = plainSelect.getJoins(); - if (CollectionUtils.isNotEmpty(joins)) { - processJoins(joins); - } + } /** @@ -379,8 +391,12 @@ protected void processFromItem(FromItem fromItem) { * 处理 joins * * @param joins join 集合 + * @return List
右连接查询的 Table 列表 */ - private void processJoins(List joins) { + private List
processJoins(List joins) { + + List
rightJointTables = new ArrayList<>(); + //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名 Deque
tables = new LinkedList<>(); for (Join join : joins) { @@ -390,13 +406,20 @@ private void processJoins(List joins) { Table fromTable = (Table) fromItem; // 获取 join 尾缀的 on 表达式列表 Collection originOnExpressions = join.getOnExpressions(); + + // 当前表是否忽略 + boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName()); + // 如果不要忽略,且是右连接,则记录下当前表 + if (!needIgnore && join.isRight()) { + rightJointTables.add(fromTable); + } + // 正常 join on 表达式只有一个,立刻处理 if (originOnExpressions.size() == 1) { processJoin(join); continue; } - // 当前表是否忽略 - boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName()); + // 表名压栈,忽略的表压入 null,以便后续不处理 tables.push(needIgnore ? null : fromTable); // 尾缀多个 on 表达式的时候统一处理 @@ -417,6 +440,8 @@ private void processJoins(List joins) { processFromItem(fromItem); } } + + return rightJointTables; } /** @@ -441,16 +466,47 @@ protected void processJoin(Join join) { * 处理条件 */ protected Expression builderExpression(Expression currentExpression, Table table) { - EqualsTo equalsTo = new EqualsTo(); - equalsTo.setLeftExpression(this.getAliasColumn(table)); - equalsTo.setRightExpression(tenantLineHandler.getTenantId()); + return builderExpression(currentExpression, table, new ArrayList<>()); + } + + /** + * 处理条件 + */ + protected Expression builderExpression(Expression currentExpression, Table table, List
rightJointTables) { + // 没有表需要处理直接返回 + if(table == null && CollectionUtils.isEmpty(rightJointTables)){ + return currentExpression; + } + + // 当前需要处理的表 + List
tables = new ArrayList<>(); + if(table != null){ + tables.add(table); + } + tables.addAll(rightJointTables); + + // 租户 + Expression tenantId = tenantLineHandler.getTenantId(); + // 构造每张表的条件 + List equalsTos = tables.stream() + .map(item -> new EqualsTo(getAliasColumn(item), tenantId)) + .collect(Collectors.toList()); + // 注入的表达式 + Expression injectExpression = equalsTos.get(0); + // 如果有多表,则用 and 连接 + if(equalsTos.size() > 1){ + for (int i = 1; i < equalsTos.size(); i++) { + injectExpression = new AndExpression(injectExpression, equalsTos.get(i)); + } + } + if (currentExpression == null) { - return equalsTo; + return injectExpression; } if (currentExpression instanceof OrExpression) { - return new AndExpression(new Parenthesis(currentExpression), equalsTo); + return new AndExpression(new Parenthesis(currentExpression), injectExpression); } else { - return new AndExpression(currentExpression, equalsTo); + return new AndExpression(currentExpression, injectExpression); } } diff --git a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java index d0576a85e9..26f3a64cc4 100644 --- a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java +++ b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java @@ -187,14 +187,28 @@ void selectRightJoin() { "right join entity1 e1 on e1.id = e.id", "SELECT * FROM entity e " + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE e.tenant_id = 1"); + "WHERE e.tenant_id = 1 AND e1.tenant_id = 1"); + + assertSql("SELECT * FROM with_as_1 e " + + "right join entity1 e1 on e1.id = e.id", + "SELECT * FROM with_as_1 e " + + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "WHERE e1.tenant_id = 1"); assertSql("SELECT * FROM entity e " + "right join entity1 e1 on e1.id = e.id " + "WHERE e.id = ? OR e.name = ?", "SELECT * FROM entity e " + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e1.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "right join entity1 e1 on e1.id = e.id " + + "right join entity2 e2 on e1.id = e2.id ", + "SELECT * FROM entity e " + + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " + + "WHERE e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1"); } @Test From f8f209c261e5c0f3154faf17f1230ccf9b7133f1 Mon Sep 17 00:00:00 2001 From: Hccake Date: Fri, 12 Nov 2021 18:38:04 +0800 Subject: [PATCH 2/3] =?UTF-8?q?:ambulance:=20=E8=B0=83=E6=95=B4=E4=BA=86?= =?UTF-8?q?=E5=A4=9A=E7=A7=9F=E6=88=B7=20sql=20=E8=A7=A3=E6=9E=90=E6=B5=81?= =?UTF-8?q?=E7=A8=8B=EF=BC=8C=E4=BF=AE=E5=A4=8D=20right=20join=E3=80=81sub?= =?UTF-8?q?Join=20=E7=9A=84=E9=97=AE=E9=A2=98=EF=BC=8C=E4=BB=A5=E5=8F=8A?= =?UTF-8?q?=E4=BC=98=E5=8C=96=20innerJoin?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../inner/TenantLineInnerInterceptor.java | 189 ++++++++++-------- .../inner/TenantLineInnerInterceptorTest.java | 126 ++++++++++-- 2 files changed, 217 insertions(+), 98 deletions(-) diff --git a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java index bb3fd1ce1f..fb7eefdce8 100644 --- a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java +++ b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java @@ -41,8 +41,9 @@ import java.sql.Connection; import java.sql.SQLException; -import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Deque; import java.util.LinkedList; import java.util.List; @@ -235,38 +236,48 @@ protected void appendSelectItem(List selectItems) { * 处理 PlainSelect */ protected void processPlainSelect(PlainSelect plainSelect) { - FromItem fromItem = plainSelect.getFromItem(); - //#3087 github List selectItems = plainSelect.getSelectItems(); if (CollectionUtils.isNotEmpty(selectItems)) { selectItems.forEach(this::processSelectItem); } - // #I4FP6E gitee:右连接查询时,where 条件需要过滤 - List
rightJointTables; + // 处理 where 中的子查询 + Expression where = plainSelect.getWhere(); + processWhereSubSelect(where); + + // 处理 fromItem + FromItem fromItem = plainSelect.getFromItem(); + Table mainTable = processFromItem(fromItem); + + // 处理 join List joins = plainSelect.getJoins(); if (CollectionUtils.isNotEmpty(joins)) { - rightJointTables = processJoins(joins); - }else { - rightJointTables = new ArrayList<>(); + mainTable = processJoins(mainTable, joins); } - Expression where = plainSelect.getWhere(); - processWhereSubSelect(where); + // 当有 mainTable 时,进行 where 条件追加 + if (mainTable != null) { + plainSelect.setWhere(builderExpression(where, Collections.singletonList(mainTable))); + } + } + + private Table processFromItem(FromItem fromItem) { + Table mainTable = null; + // 无 join 时的处理逻辑 if (fromItem instanceof Table) { Table fromTable = (Table) fromItem; - boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName()); - if (needIgnore) { - plainSelect.setWhere(builderExpression(where, null, rightJointTables)); - }else { - //#1186 github - plainSelect.setWhere(builderExpression(where, fromTable, rightJointTables)); + if (!tenantLineHandler.ignoreTable(fromTable.getName())) { + mainTable = fromTable; } + } else if (fromItem instanceof SubJoin) { + // SubJoin 类型则还需要添加上 where 条件 + mainTable = processSubJoin((SubJoin) fromItem); } else { - processFromItem(fromItem); + // 处理下 fromItem + processOtherFromItem(fromItem); } - + return mainTable; } /** @@ -294,7 +305,7 @@ protected void processWhereSubSelect(Expression where) { return; } if (where instanceof FromItem) { - processFromItem((FromItem) where); + processOtherFromItem((FromItem) where); return; } if (where.toString().indexOf("SELECT") > 0) { @@ -360,16 +371,8 @@ protected void processFunction(Function function) { /** * 处理子查询等 */ - protected void processFromItem(FromItem fromItem) { - if (fromItem instanceof SubJoin) { - SubJoin subJoin = (SubJoin) fromItem; - if (subJoin.getJoinList() != null) { - processJoins(subJoin.getJoinList()); - } - if (subJoin.getLeft() != null) { - processFromItem(subJoin.getLeft()); - } - } else if (fromItem instanceof SubSelect) { + protected void processOtherFromItem(FromItem fromItem) { + if (fromItem instanceof SubSelect) { SubSelect subSelect = (SubSelect) fromItem; if (subSelect.getSelectBody() != null) { processSelectBody(subSelect.getSelectBody()); @@ -387,104 +390,118 @@ protected void processFromItem(FromItem fromItem) { } } + /** + * 处理 sub join + * + * @param subJoin subJoin + * @return Table subJoin 中的主表 + */ + private Table processSubJoin(SubJoin subJoin) { + Table mainTable = null; + if (subJoin.getJoinList() != null) { + mainTable = processFromItem(subJoin.getLeft()); + mainTable = processJoins(mainTable, subJoin.getJoinList()); + } + return mainTable; + } + /** * 处理 joins * - * @param joins join 集合 + * @param fromTable 可以为 null + * @param joins join 集合 * @return List
右连接查询的 Table 列表 */ - private List
processJoins(List joins) { - - List
rightJointTables = new ArrayList<>(); + private Table processJoins(Table fromTable, List joins) { + // join 表达式中最终的主表 + Table mainTable = fromTable; + // 当前 join 的左表 + Table leftTable = fromTable; //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名 - Deque
tables = new LinkedList<>(); + Deque> onTableDeque = new LinkedList<>(); for (Join join : joins) { + List
onTables = null; // 处理 on 表达式 - FromItem fromItem = join.getRightItem(); - if (fromItem instanceof Table) { - Table fromTable = (Table) fromItem; + FromItem joinItem = join.getRightItem(); + + // 获取当前 join 的表,subJoint 可以看作是一张表 + Table joinTable = null; + if (joinItem instanceof Table) { + joinTable = (Table) joinItem; + } else if (joinItem instanceof SubJoin) { + joinTable = processSubJoin((SubJoin) joinItem); + } + + if (joinTable != null) { // 获取 join 尾缀的 on 表达式列表 Collection originOnExpressions = join.getOnExpressions(); // 当前表是否忽略 - boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName()); + boolean joinTableNeedIgnore = tenantLineHandler.ignoreTable(joinTable.getName()); + // 如果不要忽略,且是右连接,则记录下当前表 - if (!needIgnore && join.isRight()) { - rightJointTables.add(fromTable); + if (join.isRight()) { + mainTable = joinTableNeedIgnore ? null : joinTable; + if (leftTable != null) { + onTables = Collections.singletonList(leftTable); + } + } else if (join.isLeft()) { + if (!joinTableNeedIgnore) { + onTables = Collections.singletonList(joinTable); + } + } else if (join.isInner()) { + if (mainTable == null) { + onTables = Collections.singletonList(joinTable); + } else { + onTables = Arrays.asList(mainTable, joinTable); + } + mainTable = null; } // 正常 join on 表达式只有一个,立刻处理 - if (originOnExpressions.size() == 1) { - processJoin(join); + if (originOnExpressions.size() == 1 && onTables != null) { + List onExpressions = new LinkedList<>(); + onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables)); + join.setOnExpressions(onExpressions); + leftTable = joinTable; continue; } // 表名压栈,忽略的表压入 null,以便后续不处理 - tables.push(needIgnore ? null : fromTable); + onTableDeque.push(onTables); // 尾缀多个 on 表达式的时候统一处理 if (originOnExpressions.size() > 1) { Collection onExpressions = new LinkedList<>(); for (Expression originOnExpression : originOnExpressions) { - Table currentTable = tables.poll(); - if (currentTable == null) { + List
currentTableList = onTableDeque.poll(); + if (CollectionUtils.isEmpty(currentTableList)) { onExpressions.add(originOnExpression); } else { - onExpressions.add(builderExpression(originOnExpression, currentTable)); + onExpressions.add(builderExpression(originOnExpression, currentTableList)); } } join.setOnExpressions(onExpressions); } + leftTable = joinTable; } else { - // 处理右边连接的子表达式 - processFromItem(fromItem); + processOtherFromItem(joinItem); + leftTable = null; } - } - return rightJointTables; - } - - /** - * 处理联接语句 - */ - protected void processJoin(Join join) { - if (join.getRightItem() instanceof Table) { - Table fromTable = (Table) join.getRightItem(); - if (tenantLineHandler.ignoreTable(fromTable.getName())) { - // 过滤退出执行 - return; - } - // 走到这里说明 on 表达式肯定只有一个 - Collection originOnExpressions = join.getOnExpressions(); - List onExpressions = new LinkedList<>(); - onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable)); - join.setOnExpressions(onExpressions); } - } - /** - * 处理条件 - */ - protected Expression builderExpression(Expression currentExpression, Table table) { - return builderExpression(currentExpression, table, new ArrayList<>()); + return mainTable; } /** * 处理条件 */ - protected Expression builderExpression(Expression currentExpression, Table table, List
rightJointTables) { - // 没有表需要处理直接返回 - if(table == null && CollectionUtils.isEmpty(rightJointTables)){ - return currentExpression; - } - - // 当前需要处理的表 - List
tables = new ArrayList<>(); - if(table != null){ - tables.add(table); + protected Expression builderExpression(Expression currentExpression, List
tables) { + // 没有表需要处理直接返回 + if (CollectionUtils.isEmpty(tables)) { + return currentExpression; } - tables.addAll(rightJointTables); - // 租户 Expression tenantId = tenantLineHandler.getTenantId(); // 构造每张表的条件 @@ -494,7 +511,7 @@ protected Expression builderExpression(Expression currentExpression, Table table // 注入的表达式 Expression injectExpression = equalsTos.get(0); // 如果有多表,则用 and 连接 - if(equalsTos.size() > 1){ + if (equalsTos.size() > 1) { for (int i = 1; i < equalsTos.size(); i++) { injectExpression = new AndExpression(injectExpression, equalsTos.get(i)); } diff --git a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java index 26f3a64cc4..91acb2f85f 100644 --- a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java +++ b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java @@ -178,6 +178,14 @@ void selectLeftJoin() { "SELECT * FROM entity e " + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "left join entity1 e1 on e1.id = e.id " + + "left join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " + + "WHERE e.tenant_id = 1"); } @Test @@ -186,31 +194,125 @@ void selectRightJoin() { assertSql("SELECT * FROM entity e " + "right join entity1 e1 on e1.id = e.id", "SELECT * FROM entity e " + - "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE e.tenant_id = 1 AND e1.tenant_id = 1"); + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "WHERE e1.tenant_id = 1"); assertSql("SELECT * FROM with_as_1 e " + "right join entity1 e1 on e1.id = e.id", "SELECT * FROM with_as_1 e " + - "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "RIGHT JOIN entity1 e1 ON e1.id = e.id " + "WHERE e1.tenant_id = 1"); assertSql("SELECT * FROM entity e " + "right join entity1 e1 on e1.id = e.id " + "WHERE e.id = ? OR e.name = ?", "SELECT * FROM entity e " + - "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e1.tenant_id = 1"); + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1"); assertSql("SELECT * FROM entity e " + "right join entity1 e1 on e1.id = e.id " + "right join entity2 e2 on e1.id = e2.id ", "SELECT * FROM entity e " + - "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " + - "WHERE e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1"); + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " + + "WHERE e2.tenant_id = 1"); } + @Test + void selectMixJoin(){ + assertSql("SELECT * FROM entity e " + + "right join entity1 e1 on e1.id = e.id " + + "left join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " + + "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " + + "WHERE e1.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "left join entity1 e1 on e1.id = e.id " + + "right join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " + + "WHERE e2.tenant_id = 1"); + + assertSql("SELECT * FROM entity e " + + "left join entity1 e1 on e1.id = e.id " + + "inner join entity2 e2 on e1.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + + "INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1"); + } + + + @Test + void selectJoinSubSelect(){ + assertSql("select * from (select * from entity) e1 " + + "left join entity2 e2 on e1.id = e2.id", + "SELECT * FROM (SELECT * FROM entity WHERE tenant_id = 1) e1 " + + "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1"); + + assertSql("select * from entity1 e1 " + + "left join (select * from entity2) e2 " + + "on e1.id = e2.id", + "SELECT * FROM entity1 e1 " + + "LEFT JOIN (SELECT * FROM entity2 WHERE tenant_id = 1) e2 " + + "ON e1.id = e2.id " + + "WHERE e1.tenant_id = 1"); + } + + @Test + void selectSubJoin(){ + + assertSql("select * FROM " + + "(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)", + "SELECT * FROM " + + "(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " + + "WHERE e2.tenant_id = 1"); + + assertSql("select * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)", + "SELECT * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "WHERE e1.tenant_id = 1"); + + + assertSql("select * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " + + "right join entity3 e3 on e1.id = e3.id", + "SELECT * FROM " + + "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " + + "WHERE e3.tenant_id = 1"); + + + assertSql("select * FROM entity e " + + "LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " + + "on e.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " + + "ON e.id = e2.id AND e2.tenant_id = 1 " + + "WHERE e.tenant_id = 1"); + + assertSql("select * FROM entity e " + + "LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " + + "on e.id = e2.id", + "SELECT * FROM entity e " + + "LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "ON e.id = e2.id AND e1.tenant_id = 1 " + + "WHERE e.tenant_id = 1"); + + assertSql("select * FROM entity e " + + "RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " + + "on e.id = e2.id", + "SELECT * FROM entity e " + + "RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " + + "ON e.id = e2.id AND e.tenant_id = 1 " + + "WHERE e1.tenant_id = 1"); + } + + @Test void selectLeftJoinMultipleTrailingOn() { // 多个 on 尾缀的 @@ -244,15 +346,15 @@ void selectInnerJoin() { "inner join entity1 e1 on e1.id = e.id " + "WHERE e.id = ? OR e.name = ?", "SELECT * FROM entity e " + - "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " + + "WHERE e.id = ? OR e.name = ?"); assertSql("SELECT * FROM entity e " + "inner join entity1 e1 on e1.id = e.id " + "WHERE (e.id = ? OR e.name = ?)", "SELECT * FROM entity e " + - "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + - "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " + + "WHERE (e.id = ? OR e.name = ?)"); // 垃圾 inner join todo // assertSql("SELECT * FROM entity,entity1 " + From 5ba2470814074a3767a2e3928b7ed0e777e4b5c7 Mon Sep 17 00:00:00 2001 From: Hccake Date: Sat, 13 Nov 2021 16:58:01 +0800 Subject: [PATCH 3/3] =?UTF-8?q?:bug:=20=E8=A7=A3=E5=86=B3=E5=B9=B4?= =?UTF-8?q?=E4=B9=85=E5=A4=B1=E4=BF=AE=E7=9A=84=E9=9A=90=E5=BC=8F=E5=86=85?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../inner/TenantLineInnerInterceptor.java | 96 +++++++++++++------ .../inner/TenantLineInnerInterceptorTest.java | 74 +++++++++----- 2 files changed, 116 insertions(+), 54 deletions(-) diff --git a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java index fb7eefdce8..64100fe90a 100644 --- a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java +++ b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java @@ -41,6 +41,7 @@ import java.sql.Connection; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -248,36 +249,43 @@ protected void processPlainSelect(PlainSelect plainSelect) { // 处理 fromItem FromItem fromItem = plainSelect.getFromItem(); - Table mainTable = processFromItem(fromItem); + List
list = processFromItem(fromItem); + List
mainTables = new ArrayList<>(list); // 处理 join List joins = plainSelect.getJoins(); if (CollectionUtils.isNotEmpty(joins)) { - mainTable = processJoins(mainTable, joins); + mainTables = processJoins(mainTables, joins); } // 当有 mainTable 时,进行 where 条件追加 - if (mainTable != null) { - plainSelect.setWhere(builderExpression(where, Collections.singletonList(mainTable))); + if (CollectionUtils.isNotEmpty(mainTables)) { + plainSelect.setWhere(builderExpression(where, mainTables)); } } - private Table processFromItem(FromItem fromItem) { - Table mainTable = null; + private List
processFromItem(FromItem fromItem) { + // 处理括号括起来的表达式 + while (fromItem instanceof ParenthesisFromItem) { + fromItem = ((ParenthesisFromItem) fromItem).getFromItem(); + } + + List
mainTables = new ArrayList<>(); // 无 join 时的处理逻辑 if (fromItem instanceof Table) { Table fromTable = (Table) fromItem; if (!tenantLineHandler.ignoreTable(fromTable.getName())) { - mainTable = fromTable; + mainTables.add(fromTable); } } else if (fromItem instanceof SubJoin) { // SubJoin 类型则还需要添加上 where 条件 - mainTable = processSubJoin((SubJoin) fromItem); + List
tables = processSubJoin((SubJoin) fromItem); + mainTables.addAll(tables); } else { // 处理下 fromItem processOtherFromItem(fromItem); } - return mainTable; + return mainTables; } /** @@ -372,6 +380,11 @@ protected void processFunction(Function function) { * 处理子查询等 */ protected void processOtherFromItem(FromItem fromItem) { + // 去除括号 + while (fromItem instanceof ParenthesisFromItem) { + fromItem = ((ParenthesisFromItem) fromItem).getFromItem(); + } + if (fromItem instanceof SubSelect) { SubSelect subSelect = (SubSelect) fromItem; if (subSelect.getSelectBody() != null) { @@ -396,50 +409,65 @@ protected void processOtherFromItem(FromItem fromItem) { * @param subJoin subJoin * @return Table subJoin 中的主表 */ - private Table processSubJoin(SubJoin subJoin) { - Table mainTable = null; + private List
processSubJoin(SubJoin subJoin) { + List
mainTables = new ArrayList<>(); if (subJoin.getJoinList() != null) { - mainTable = processFromItem(subJoin.getLeft()); - mainTable = processJoins(mainTable, subJoin.getJoinList()); + List
list = processFromItem(subJoin.getLeft()); + mainTables.addAll(list); + mainTables = processJoins(mainTables, subJoin.getJoinList()); } - return mainTable; + return mainTables; } /** * 处理 joins * - * @param fromTable 可以为 null - * @param joins join 集合 + * @param mainTables 可以为 null + * @param joins join 集合 * @return List
右连接查询的 Table 列表 */ - private Table processJoins(Table fromTable, List joins) { + private List
processJoins(List
mainTables, List joins) { + if (mainTables == null) { + mainTables = new ArrayList<>(); + } + // join 表达式中最终的主表 - Table mainTable = fromTable; + Table mainTable = null; // 当前 join 的左表 - Table leftTable = fromTable; + Table leftTable = null; + if (mainTables.size() == 1) { + mainTable = mainTables.get(0); + leftTable = mainTable; + } //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名 Deque> onTableDeque = new LinkedList<>(); for (Join join : joins) { - List
onTables = null; // 处理 on 表达式 FromItem joinItem = join.getRightItem(); // 获取当前 join 的表,subJoint 可以看作是一张表 - Table joinTable = null; + List
joinTables = null; if (joinItem instanceof Table) { - joinTable = (Table) joinItem; + joinTables = new ArrayList<>(); + joinTables.add((Table) joinItem); } else if (joinItem instanceof SubJoin) { - joinTable = processSubJoin((SubJoin) joinItem); + joinTables = processSubJoin((SubJoin) joinItem); } - if (joinTable != null) { - // 获取 join 尾缀的 on 表达式列表 - Collection originOnExpressions = join.getOnExpressions(); + if (joinTables != null) { + + // 如果是隐式内连接 + if (join.isSimple()) { + mainTables.addAll(joinTables); + continue; + } // 当前表是否忽略 + Table joinTable = joinTables.get(0); boolean joinTableNeedIgnore = tenantLineHandler.ignoreTable(joinTable.getName()); + List
onTables = null; // 如果不要忽略,且是右连接,则记录下当前表 if (join.isRight()) { mainTable = joinTableNeedIgnore ? null : joinTable; @@ -458,7 +486,13 @@ private Table processJoins(Table fromTable, List joins) { } mainTable = null; } + mainTables = new ArrayList<>(); + if (mainTable != null) { + mainTables.add(mainTable); + } + // 获取 join 尾缀的 on 表达式列表 + Collection originOnExpressions = join.getOnExpressions(); // 正常 join on 表达式只有一个,立刻处理 if (originOnExpressions.size() == 1 && onTables != null) { List onExpressions = new LinkedList<>(); @@ -467,7 +501,6 @@ private Table processJoins(Table fromTable, List joins) { leftTable = joinTable; continue; } - // 表名压栈,忽略的表压入 null,以便后续不处理 onTableDeque.push(onTables); // 尾缀多个 on 表达式的时候统一处理 @@ -491,7 +524,7 @@ private Table processJoins(Table fromTable, List joins) { } - return mainTable; + return mainTables; } /** @@ -536,10 +569,13 @@ protected Expression builderExpression(Expression currentExpression, List
*/ protected Column getAliasColumn(Table table) { StringBuilder column = new StringBuilder(); + // 为了兼容隐式内连接,没有别名时条件就需要加上表名 if (table.getAlias() != null) { - column.append(table.getAlias().getName()).append(StringPool.DOT); + column.append(table.getAlias().getName()); + } else { + column.append(table.getName()); } - column.append(tenantLineHandler.getTenantIdColumn()); + column.append(StringPool.DOT).append(tenantLineHandler.getTenantIdColumn()); return new Column(column.toString()); } diff --git a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java index 91acb2f85f..56eea3f2be 100644 --- a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java +++ b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java @@ -38,48 +38,48 @@ void insert() { "INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, ?)"); // insert into select assertSql("insert into entity (id,name) select id,name from entity2", - "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM entity2 WHERE tenant_id = 1"); + "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM entity2 WHERE entity2.tenant_id = 1"); assertSql("insert into entity (id,name) select * from entity2", - "INSERT INTO entity (id, name, tenant_id) SELECT * FROM entity2 WHERE tenant_id = 1"); + "INSERT INTO entity (id, name, tenant_id) SELECT * FROM entity2 WHERE entity2.tenant_id = 1"); assertSql("insert into entity (id,name) select id,name from (select id,name from entity3) t", - "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t"); + "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM (SELECT id, name, tenant_id FROM entity3 WHERE entity3.tenant_id = 1) t"); assertSql("insert into entity (id,name) select * from (select id,name from entity3) t", - "INSERT INTO entity (id, name, tenant_id) SELECT * FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t"); + "INSERT INTO entity (id, name, tenant_id) SELECT * FROM (SELECT id, name, tenant_id FROM entity3 WHERE entity3.tenant_id = 1) t"); assertSql("insert into entity (id,name) select t.* from (select id,name from entity3) t", - "INSERT INTO entity (id, name, tenant_id) SELECT t.* FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t"); + "INSERT INTO entity (id, name, tenant_id) SELECT t.* FROM (SELECT id, name, tenant_id FROM entity3 WHERE entity3.tenant_id = 1) t"); } @Test void delete() { assertSql("delete from entity where id = ?", - "DELETE FROM entity WHERE tenant_id = 1 AND id = ?"); + "DELETE FROM entity WHERE entity.tenant_id = 1 AND id = ?"); } @Test void update() { assertSql("update entity set name = ? where id = ?", - "UPDATE entity SET name = ? WHERE tenant_id = 1 AND id = ?"); + "UPDATE entity SET name = ? WHERE entity.tenant_id = 1 AND id = ?"); } @Test void selectSingle() { // 单表 assertSql("select * from entity where id = ?", - "SELECT * FROM entity WHERE id = ? AND tenant_id = 1"); + "SELECT * FROM entity WHERE id = ? AND entity.tenant_id = 1"); assertSql("select * from entity where id = ? or name = ?", - "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1"); + "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1"); assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)", - "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1"); + "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1"); /* not */ assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)", - "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1"); + "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND entity.tenant_id = 1"); } @Test @@ -220,7 +220,7 @@ void selectRightJoin() { } @Test - void selectMixJoin(){ + void selectMixJoin() { assertSql("SELECT * FROM entity e " + "right join entity1 e1 on e1.id = e.id " + "left join entity2 e2 on e1.id = e2.id", @@ -247,23 +247,23 @@ void selectMixJoin(){ @Test - void selectJoinSubSelect(){ + void selectJoinSubSelect() { assertSql("select * from (select * from entity) e1 " + - "left join entity2 e2 on e1.id = e2.id", - "SELECT * FROM (SELECT * FROM entity WHERE tenant_id = 1) e1 " + + "left join entity2 e2 on e1.id = e2.id", + "SELECT * FROM (SELECT * FROM entity WHERE entity.tenant_id = 1) e1 " + "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1"); assertSql("select * from entity1 e1 " + "left join (select * from entity2) e2 " + "on e1.id = e2.id", "SELECT * FROM entity1 e1 " + - "LEFT JOIN (SELECT * FROM entity2 WHERE tenant_id = 1) e2 " + + "LEFT JOIN (SELECT * FROM entity2 WHERE entity2.tenant_id = 1) e2 " + "ON e1.id = e2.id " + "WHERE e1.tenant_id = 1"); } @Test - void selectSubJoin(){ + void selectSubJoin() { assertSql("select * FROM " + "(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)", @@ -356,19 +356,45 @@ void selectInnerJoin() { "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " + "WHERE (e.id = ? OR e.name = ?)"); - // 垃圾 inner join todo -// assertSql("SELECT * FROM entity,entity1 " + -// "WHERE entity.id = entity1.id", -// "SELECT * FROM entity e " + -// "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " + -// "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1"); + // 隐式内连接 + assertSql("SELECT * FROM entity,entity1 " + + "WHERE entity.id = entity1.id", + "SELECT * FROM entity, entity1 " + + "WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1"); + + // SubJoin with 隐式内连接 + assertSql("SELECT * FROM (entity,entity1) " + + "WHERE entity.id = entity1.id", + "SELECT * FROM (entity, entity1) " + + "WHERE entity.id = entity1.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1"); + + assertSql("SELECT * FROM ((entity,entity1),entity2) " + + "WHERE entity.id = entity1.id and entity.id = entity2.id", + "SELECT * FROM ((entity, entity1), entity2) " + + "WHERE entity.id = entity1.id AND entity.id = entity2.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1"); + + assertSql("SELECT * FROM (entity,(entity1,entity2)) " + + "WHERE entity.id = entity1.id and entity.id = entity2.id", + "SELECT * FROM (entity, (entity1, entity2)) " + + "WHERE entity.id = entity1.id AND entity.id = entity2.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1"); + + // 沙雕的括号写法 + assertSql("SELECT * FROM (((entity,entity1))) " + + "WHERE entity.id = entity1.id", + "SELECT * FROM (((entity, entity1))) " + + "WHERE entity.id = entity1.id " + + "AND entity.tenant_id = 1 AND entity1.tenant_id = 1"); + } @Test void selectWithAs() { assertSql("with with_as_A as (select * from entity) select * from with_as_A", - "WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A"); + "WITH with_as_A AS (SELECT * FROM entity WHERE entity.tenant_id = 1) SELECT * FROM with_as_A"); } void assertSql(String sql, String targetSql) {