Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 fix gitee I4FP6E, right join bug #4035

Merged
merged 3 commits into from
Dec 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@

import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -234,27 +237,55 @@ protected void appendSelectItem(List<SelectItem> selectItems) {
* 处理 PlainSelect
*/
protected void processPlainSelect(PlainSelect plainSelect) {
FromItem fromItem = plainSelect.getFromItem();
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
//#1186 github
plainSelect.setWhere(builderExpression(where, fromTable));
}
} else {
processFromItem(fromItem);
}
//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}

// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);

// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem);
List<Table> mainTables = new ArrayList<>(list);

// 处理 join
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
processJoins(joins);
mainTables = processJoins(mainTables, joins);
}

// 当有 mainTable 时,进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables));
}
}

private List<Table> processFromItem(FromItem fromItem) {
// 处理括号括起来的表达式
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}

List<Table> mainTables = new ArrayList<>();
// 无 join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
mainTables.add(fromTable);
}
} else if (fromItem instanceof SubJoin) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((SubJoin) fromItem);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem);
}
return mainTables;
}

/**
Expand Down Expand Up @@ -282,7 +313,7 @@ protected void processWhereSubSelect(Expression where) {
return;
}
if (where instanceof FromItem) {
processFromItem((FromItem) where);
processOtherFromItem((FromItem) where);
return;
}
if (where.toString().indexOf("SELECT") > 0) {
Expand Down Expand Up @@ -348,16 +379,13 @@ protected void processFunction(Function function) {
/**
* 处理子查询等
*/
protected void processFromItem(FromItem fromItem) {
if (fromItem instanceof SubJoin) {
SubJoin subJoin = (SubJoin) fromItem;
if (subJoin.getJoinList() != null) {
processJoins(subJoin.getJoinList());
}
if (subJoin.getLeft() != null) {
processFromItem(subJoin.getLeft());
}
} else if (fromItem instanceof SubSelect) {
protected void processOtherFromItem(FromItem fromItem) {
// 去除括号
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}

if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
Expand All @@ -375,82 +403,160 @@ protected void processFromItem(FromItem fromItem) {
}
}

/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(SubJoin subJoin) {
List<Table> mainTables = new ArrayList<>();
if (subJoin.getJoinList() != null) {
List<Table> list = processFromItem(subJoin.getLeft());
mainTables.addAll(list);
mainTables = processJoins(mainTables, subJoin.getJoinList());
}
return mainTables;
}

/**
* 处理 joins
*
* @param joins join 集合
* @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private void processJoins(List<Join> joins) {
private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
if (mainTables == null) {
mainTables = new ArrayList<>();
}

// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}

//对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
Deque<Table> tables = new LinkedList<>();
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// 处理 on 表达式
FromItem fromItem = join.getRightItem();
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
FromItem joinItem = join.getRightItem();

// 获取当前 join 的表,subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof SubJoin) {
joinTables = processSubJoin((SubJoin) joinItem);
}

if (joinTables != null) {

// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}

// 当前表是否忽略
Table joinTable = joinTables.get(0);
boolean joinTableNeedIgnore = tenantLineHandler.ignoreTable(joinTable.getName());

List<Table> onTables = null;
// 如果不要忽略,且是右连接,则记录下当前表
if (join.isRight()) {
mainTable = joinTableNeedIgnore ? null : joinTable;
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isLeft()) {
if (!joinTableNeedIgnore) {
onTables = Collections.singletonList(joinTable);
}
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
}
mainTables = new ArrayList<>();
if (mainTable != null) {
mainTables.add(mainTable);
}

// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个,立刻处理
if (originOnExpressions.size() == 1) {
processJoin(join);
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
join.setOnExpressions(onExpressions);
leftTable = joinTable;
continue;
}
// 当前表是否忽略
boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName());
// 表名压栈,忽略的表压入 null,以便后续不处理
tables.push(needIgnore ? null : fromTable);
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
Table currentTable = tables.poll();
if (currentTable == null) {
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTable));
onExpressions.add(builderExpression(originOnExpression, currentTableList));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
} else {
// 处理右边连接的子表达式
processFromItem(fromItem);
processOtherFromItem(joinItem);
leftTable = null;
}

}

return mainTables;
}

/**
* 处理联接语句
* 处理条件
*/
protected void processJoin(Join join) {
if (join.getRightItem() instanceof Table) {
Table fromTable = (Table) join.getRightItem();
if (tenantLineHandler.ignoreTable(fromTable.getName())) {
// 过滤退出执行
return;
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(tables)) {
return currentExpression;
}
// 租户
Expression tenantId = tenantLineHandler.getTenantId();
// 构造每张表的条件
List<EqualsTo> equalsTos = tables.stream()
.map(item -> new EqualsTo(getAliasColumn(item), tenantId))
.collect(Collectors.toList());
// 注入的表达式
Expression injectExpression = equalsTos.get(0);
// 如果有多表,则用 and 连接
if (equalsTos.size() > 1) {
for (int i = 1; i < equalsTos.size(); i++) {
injectExpression = new AndExpression(injectExpression, equalsTos.get(i));
}
// 走到这里说明 on 表达式肯定只有一个
Collection<Expression> originOnExpressions = join.getOnExpressions();
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable));
join.setOnExpressions(onExpressions);
}
}

/**
* 处理条件
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(this.getAliasColumn(table));
equalsTo.setRightExpression(tenantLineHandler.getTenantId());
if (currentExpression == null) {
return equalsTo;
return injectExpression;
}
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), equalsTo);
return new AndExpression(new Parenthesis(currentExpression), injectExpression);
} else {
return new AndExpression(currentExpression, equalsTo);
return new AndExpression(currentExpression, injectExpression);
}
}

Expand All @@ -463,10 +569,13 @@ protected Expression builderExpression(Expression currentExpression, Table table
*/
protected Column getAliasColumn(Table table) {
StringBuilder column = new StringBuilder();
// 为了兼容隐式内连接,没有别名时条件就需要加上表名
if (table.getAlias() != null) {
column.append(table.getAlias().getName()).append(StringPool.DOT);
column.append(table.getAlias().getName());
} else {
column.append(table.getName());
}
column.append(tenantLineHandler.getTenantIdColumn());
column.append(StringPool.DOT).append(tenantLineHandler.getTenantIdColumn());
return new Column(column.toString());
}

Expand Down
Loading