Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix export sql error and the generated sql when modify web tabledata error #1448

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.time.LocalDateTime;
import java.util.List;

import ai.chat2db.spi.model.Header;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.SQLUtils.FormatOption;
Expand Down Expand Up @@ -103,9 +104,9 @@ public void export(@Valid @RequestBody DataExportRequest request, HttpServletRes

response.setCharacterEncoding("utf-8");
String fileName = URLEncoder.encode(
tableName + "_" + LocalDateTime.now().format(DatePattern.PURE_DATETIME_FORMATTER),
StandardCharsets.UTF_8)
.replaceAll("\\+", "%20");
tableName + "_" + LocalDateTime.now().format(DatePattern.PURE_DATETIME_FORMATTER),
StandardCharsets.UTF_8)
.replaceAll("\\+", "%20");

if (exportType == ExportTypeEnum.CSV) {
doExportCsv(sql, response, fileName);
Expand All @@ -115,19 +116,19 @@ public void export(@Valid @RequestBody DataExportRequest request, HttpServletRes
}

private void doExportCsv(String sql, HttpServletResponse response, String fileName)
throws IOException {
throws IOException {
response.setContentType("text/csv");
response.setHeader("Content-disposition", "attachment;filename*=utf-8''" + fileName + ".csv");

ExcelWrapper excelWrapper = new ExcelWrapper();
try {
ExcelWriterBuilder excelWriterBuilder = EasyExcel.write(response.getOutputStream())
.charset(StandardCharsets.UTF_8)
.excelType(ExcelTypeEnum.CSV);
.charset(StandardCharsets.UTF_8)
.excelType(ExcelTypeEnum.CSV);
excelWrapper.setExcelWriterBuilder(excelWriterBuilder);
SQLExecutor.getInstance().execute(Chat2DBContext.getConnection(), sql, headerList -> {
excelWriterBuilder.head(
EasyCollectionUtils.toList(headerList, header -> Lists.newArrayList(header.getName())));
EasyCollectionUtils.toList(headerList, header -> Lists.newArrayList(header.getName())));
excelWrapper.setExcelWriter(excelWriterBuilder.build());
excelWrapper.setWriteSheet(EasyExcel.writerSheet(0).build());
}, dataList -> {
Expand All @@ -143,29 +144,27 @@ private void doExportCsv(String sql, HttpServletResponse response, String fileNa
}

private void doExportInsert(String sql, HttpServletResponse response, String fileName, DbType dbType,
String tableName)
throws IOException {
String tableName)
throws IOException {
response.setContentType("text/sql");
response.setHeader("Content-disposition", "attachment;filename*=utf-8''" + fileName + ".sql");

try (PrintWriter printWriter = response.getWriter()) {
InsertWrapper insertWrapper = new InsertWrapper();
SQLExecutor.getInstance().execute(Chat2DBContext.getConnection(), sql,
headerList -> insertWrapper.setHeaderList(
EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName())))
, dataList -> {
SQLInsertStatement sqlInsertStatement = new SQLInsertStatement();
sqlInsertStatement.setDbType(dbType);
sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName));
sqlInsertStatement.getColumns().addAll(insertWrapper.getHeaderList());
ValuesClause valuesClause = new ValuesClause();
for (String s : dataList) {
valuesClause.addValue(s);
}
sqlInsertStatement.setValues(valuesClause);

printWriter.println(SQLUtils.toSQLString(sqlInsertStatement, dbType, INSERT_FORMAT_OPTION) + ";");
}, false);
headerList -> insertWrapper.setHeaderList(headerList)
, dataList -> {
SQLInsertStatement sqlInsertStatement = new SQLInsertStatement();
sqlInsertStatement.setDbType(dbType);
sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName));
List<Header> headerList = insertWrapper.getHeaderList();
sqlInsertStatement.getColumns().addAll(EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName())));
ValuesClause valuesClause = SqlUtils.getValuesClause(dataList, headerList);
sqlInsertStatement.setValues(valuesClause);

printWriter.println(SQLUtils.toSQLString(sqlInsertStatement, dbType, INSERT_FORMAT_OPTION) + ";");

}, false);
}
}

Expand All @@ -174,7 +173,7 @@ private void doExportInsert(String sql, HttpServletResponse response, String fil
@NoArgsConstructor
@AllArgsConstructor
public static class InsertWrapper {
private List<SQLIdentifierExpr> headerList;
private List<Header> headerList;
}

@Data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.chat2db.server.web.api.controller.rdb.factory.ExportServiceFactory;
import ai.chat2db.server.web.api.controller.rdb.request.DataExportRequest;
import ai.chat2db.server.web.api.controller.rdb.vo.TableVO;
import ai.chat2db.spi.model.Header;
import ai.chat2db.spi.model.Table;
import ai.chat2db.spi.sql.Chat2DBContext;
import ai.chat2db.spi.sql.ConnectInfo;
Expand Down Expand Up @@ -273,20 +274,18 @@ private void doExportInsert(String sql, File file, DbType dbType,
try (PrintWriter printWriter = new PrintWriter(file, StandardCharsets.UTF_8.name())) {
RdbDmlExportController.InsertWrapper insertWrapper = new RdbDmlExportController.InsertWrapper();
SQLExecutor.getInstance().execute(Chat2DBContext.getConnection(), sql,
headerList -> insertWrapper.setHeaderList(
EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName())))
headerList -> insertWrapper.setHeaderList(headerList)
, dataList -> {
SQLInsertStatement sqlInsertStatement = new SQLInsertStatement();
sqlInsertStatement.setDbType(dbType);
sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName));
sqlInsertStatement.getColumns().addAll(insertWrapper.getHeaderList());
SQLInsertStatement.ValuesClause valuesClause = new SQLInsertStatement.ValuesClause();
for (String s : dataList) {
valuesClause.addValue(s);
}
List<Header> headerList = insertWrapper.getHeaderList();
sqlInsertStatement.getColumns().addAll(EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(header.getName())));
SQLInsertStatement.ValuesClause valuesClause = SqlUtils.getValuesClause(dataList, headerList);
sqlInsertStatement.setValues(valuesClause);

printWriter.println(SQLUtils.toSQLString(sqlInsertStatement, dbType, INSERT_FORMAT_OPTION) + ";");

}, false);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
package ai.chat2db.spi.jdbc;

import ai.chat2db.server.tools.common.util.EasyCollectionUtils;
import ai.chat2db.spi.MetaData;
import ai.chat2db.spi.SqlBuilder;
import ai.chat2db.spi.enums.DmlType;
import ai.chat2db.spi.model.*;
import ai.chat2db.spi.sql.Chat2DBContext;
import ai.chat2db.spi.util.JdbcUtils;
import ai.chat2db.spi.util.SqlUtils;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.google.common.collect.Lists;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
Expand Down Expand Up @@ -49,6 +57,7 @@ public String pageLimit(String sql, int offset, int pageNo, int pageSize) {
}

public static String CREATE_DATABASE_SQL = "CREATE DATABASE IF NOT EXISTS `%s` DEFAULT CHARACTER SET %s COLLATE %s";
private static final SQLUtils.FormatOption FORMAT_OPTION = new SQLUtils.FormatOption(true, false);

@Override
public String buildCreateDatabaseSql(Database database) {
Expand Down Expand Up @@ -116,14 +125,14 @@ public String buildSqlByQuery(QueryResult queryResult) {
String sql = "";
if ("UPDATE".equalsIgnoreCase(operation.getType())) {
sql = getUpdateSql(tableName, headerList, row, odlRow, metaSchema, keyColumns, false);
if("MYSQL".equalsIgnoreCase(dbType)){
if ("MYSQL".equalsIgnoreCase(dbType)) {
sql = sql + " LIMIT 1";
}
} else if ("CREATE".equalsIgnoreCase(operation.getType())) {
sql = getInsertSql(tableName, headerList, row, metaSchema);
} else if ("DELETE".equalsIgnoreCase(operation.getType())) {
sql = getDeleteSql(tableName, headerList, odlRow, metaSchema, keyColumns);
if("MYSQL".equalsIgnoreCase(dbType)){
if ("MYSQL".equalsIgnoreCase(dbType)) {
sql = sql + " LIMIT 1";
}
} else if ("UPDATE_COPY".equalsIgnoreCase(operation.getType())) {
Expand Down Expand Up @@ -310,10 +319,12 @@ private List<String> getPrimaryColumns(List<Header> headerList) {

private String getDeleteSql(String tableName, List<Header> headerList, List<String> row, MetaData metaSchema,
List<String> keyColumns) {
StringBuilder script = new StringBuilder();
script.append("DELETE FROM ").append(tableName).append("");
script.append(buildWhere(headerList, row, metaSchema, keyColumns));
return script.toString();
SQLDeleteStatement sqlDeleteStatement = new SQLDeleteStatement();
sqlDeleteStatement.setTableSource(new SQLExprTableSource(tableName));
sqlDeleteStatement.setWhere(buildWhereExpr(headerList, row, metaSchema, keyColumns));
DbType dbType = JdbcUtils.parse2DruidDbType(Chat2DBContext.getConnectInfo().getDbType());
String deleteSql = SQLUtils.toSQLString(sqlDeleteStatement, dbType, FORMAT_OPTION);
return deleteSql;
}

private String buildWhere(List<Header> headerList, List<String> row, MetaData metaSchema, List<String> keyColumns) {
Expand Down Expand Up @@ -357,34 +368,55 @@ private String buildWhere(List<Header> headerList, List<String> row, MetaData me
return script.toString();
}

private SQLExpr buildWhereExpr(List<Header> headerList, List<String> row, MetaData metaSchema, List<String> keyColumns) {
List<SQLBinaryOpExpr> conditions = new ArrayList<>();

if (CollectionUtils.isEmpty(keyColumns)) {
for (int i = 1; i < row.size(); i++) {
String oldValue = row.get(i);
Header header = headerList.get(i);
if (oldValue == null) {
conditions.add(SQLBinaryOpExpr.isNull(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName()))));
} else {
conditions.add(SQLBinaryOpExpr.eq(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())),
SqlUtils.getSqlExpr(oldValue, header.getDataType())));
}
}
} else {
for (int i = 1; i < row.size(); i++) {
String oldValue = row.get(i);
Header header = headerList.get(i);
String columnName = header.getName();
if (keyColumns.contains(columnName)) {
if (oldValue == null) {
conditions.add(SQLBinaryOpExpr.isNull(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName()))));
} else {
conditions.add(SQLBinaryOpExpr.eq(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())),
SqlUtils.getSqlExpr(oldValue, header.getDataType())));
}
}
}
}
SQLExpr expr = null;
for (SQLBinaryOpExpr condition : conditions) {
expr = SQLBinaryOpExpr.and(expr, condition);
}

return expr;
}

private String getInsertSql(String tableName, List<Header> headerList, List<String> row, MetaData metaSchema) {
if (CollectionUtils.isEmpty(row) || ObjectUtils.allNull(row.toArray())) {
return "";
}
StringBuilder script = new StringBuilder();
script.append("INSERT INTO ").append(tableName)
.append(" (");
for (int i = 1; i < row.size(); i++) {
Header header = headerList.get(i);
//String newValue = row.get(i);
//if (newValue != null) {
script.append(metaSchema.getMetaDataName(header.getName()))
.append(",");
// }
}
script.deleteCharAt(script.length() - 1);
script.append(") VALUES (");
for (int i = 1; i < row.size(); i++) {
String newValue = row.get(i);
//if (newValue != null) {
Header header = headerList.get(i);
script.append(SqlUtils.getSqlValue(newValue, header.getDataType()))
.append(",");
//}
}
script.deleteCharAt(script.length() - 1);
script.append(")");
return script.toString();
DbType dbType = JdbcUtils.parse2DruidDbType(Chat2DBContext.getConnectInfo().getDbType());
SQLInsertStatement sqlInsertStatement = new SQLInsertStatement();
sqlInsertStatement.setDbType(dbType);
sqlInsertStatement.setTableSource(new SQLExprTableSource(tableName));
sqlInsertStatement.getColumns().addAll(EasyCollectionUtils.toList(headerList, header -> new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName()))));
SQLInsertStatement.ValuesClause valuesClause = SqlUtils.getValuesClause(row, headerList);
sqlInsertStatement.setValues(valuesClause);
return SQLUtils.toSQLString(sqlInsertStatement, dbType, FORMAT_OPTION);

}

Expand All @@ -395,6 +427,10 @@ private String getUpdateSql(String tableName, List<Header> headerList, List<Stri
if (CollectionUtils.isEmpty(row) || CollectionUtils.isEmpty(odlRow)) {
return "";
}
DbType dbType = JdbcUtils.parse2DruidDbType(Chat2DBContext.getConnectInfo().getDbType());
SQLUpdateStatement sqlUpdateStatement = new SQLUpdateStatement();
sqlUpdateStatement.setDbType(dbType);
sqlUpdateStatement.setTableSource(new SQLExprTableSource(tableName));
script.append("UPDATE ").append(tableName).append(" set ");
for (int i = 1; i < row.size(); i++) {
String newValue = row.get(i);
Expand All @@ -403,14 +439,14 @@ private String getUpdateSql(String tableName, List<Header> headerList, List<Stri
continue;
}
Header header = headerList.get(i);
String newSqlValue = SqlUtils.getSqlValue(newValue, header.getDataType());
script.append(metaSchema.getMetaDataName(header.getName()))
.append(" = ")
.append(newSqlValue)
.append(",");
SQLUpdateSetItem sqlUpdateSetItem = new SQLUpdateSetItem();
sqlUpdateSetItem.setColumn(new SQLIdentifierExpr(metaSchema.getMetaDataName(header.getName())));
sqlUpdateSetItem.setValue(SqlUtils.getSqlExpr(newValue, header.getDataType()));
sqlUpdateStatement.addItem(sqlUpdateSetItem);
}
script.deleteCharAt(script.length() - 1);
script.append(buildWhere(headerList, odlRow, metaSchema, keyColumns));
return script.toString();
SQLExpr sqlExpr = buildWhereExpr(headerList, odlRow, metaSchema, keyColumns);
sqlUpdateStatement.addWhere(sqlExpr);

return SQLUtils.toSQLString(sqlUpdateStatement, dbType, FORMAT_OPTION);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import ai.chat2db.server.tools.base.excption.BusinessException;
import ai.chat2db.spi.enums.DataTypeEnum;
import ai.chat2db.spi.model.ExecuteResult;
import ai.chat2db.spi.model.Header;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLTableSource;
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLNullExpr;
import com.alibaba.druid.sql.ast.expr.SQLTimestampExpr;
import com.alibaba.druid.sql.ast.expr.SQLValuableExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.google.common.collect.Lists;
import com.oceanbase.tools.sqlparser.oracle.PlSqlLexer;
Expand Down Expand Up @@ -112,6 +114,28 @@ public static String getTableName(String sql, DbType dbType) {
return sqlExprTableSource.getTableName();
}

public static SQLInsertStatement.ValuesClause getValuesClause(List<String> row, List<Header> headerList) {
SQLInsertStatement.ValuesClause valuesClause = new SQLInsertStatement.ValuesClause();
for (int i = 0; i < row.size(); i++) {
String s = row.get(i);
Header header = headerList.get(i);
valuesClause.addValue(getSqlExpr(s, header.getDataType()));
}
return valuesClause;
}

public static SQLValuableExpr getSqlExpr(String value, String dataType) {
if (value == null) {
return new SQLNullExpr();
} else if (DataTypeEnum.getByCode(dataType).equals(DataTypeEnum.DATETIME)) {
return new SQLTimestampExpr(value);
} else {
return new SQLCharExpr(value);
}


}

private static SQLTableSource getSQLExprTableSource(SQLTableSource sqlTableSource) {
if (sqlTableSource instanceof SQLExprTableSource sqlExprTableSource) {
return sqlExprTableSource;
Expand Down Expand Up @@ -140,8 +164,8 @@ public static List<String> parse(String sql, DbType dbType) {
List<SplitSqlString> sqls = sqlSplitter.split(sql);
return sqls.stream().map(splitSqlString -> SQLParserUtils.removeComment(splitSqlString.getStr(), dbType)).collect(Collectors.toList());
}
}catch (Exception e){
log.error("sqlSplitter error",e);
} catch (Exception e) {
log.error("sqlSplitter error", e);
}
try {
if (DbType.mysql.equals(dbType) ||
Expand All @@ -152,8 +176,8 @@ public static List<String> parse(String sql, DbType dbType) {
sqlSplitProcessor.setDelimiter(";");
return split(sqlSplitProcessor, sql, dbType);
}
}catch (Exception e){
log.error("sqlSplitProcessor error",e);
} catch (Exception e) {
log.error("sqlSplitProcessor error", e);
}
// sql = removeDelimiter(sql);
if (StringUtils.isBlank(sql)) {
Expand Down Expand Up @@ -246,7 +270,7 @@ public static String getSqlValue(String value, String dataType) {
if (value == null) {
return null;
}
if("".equals(value)){
if ("".equals(value)) {
return "''";
}
if (DEFAULT_VALUE.equals(value)) {
Expand Down