Skip to content

Commit

Permalink
[Improve][Connector-v2] Use regex to match filedName placeholders in …
Browse files Browse the repository at this point in the history
…jdbc sink (#8222)
  • Loading branch information
dailai authored Dec 11, 2024
1 parent 2e24941 commit c02d4fe
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.seatunnel.connectors.seatunnel.jdbc.internal.executor;

import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

Expand Down Expand Up @@ -47,6 +49,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.apache.seatunnel.shade.com.google.common.base.Preconditions.checkArgument;
import static org.apache.seatunnel.shade.com.google.common.base.Preconditions.checkNotNull;
Expand Down Expand Up @@ -669,29 +673,26 @@ public static FieldNamedPreparedStatement prepareStatement(
connection.prepareStatement(parsedSQL), indexMapping);
}

private static String parseNamedStatement(String sql, Map<String, List<Integer>> paramMap) {
StringBuilder parsedSql = new StringBuilder();
int fieldIndex = 1; // SQL statement parameter index starts from 1
int length = sql.length();
for (int i = 0; i < length; i++) {
char c = sql.charAt(i);
if (':' == c) {
int j = i + 1;
while (j < length && Character.isJavaIdentifierPart(sql.charAt(j))) {
j++;
}
String parameterName = sql.substring(i + 1, j);
checkArgument(
!parameterName.isEmpty(),
"Named parameters in SQL statement must not be empty.");
paramMap.computeIfAbsent(parameterName, n -> new ArrayList<>()).add(fieldIndex);
fieldIndex++;
i = j - 1;
parsedSql.append('?');
} else {
parsedSql.append(c);
}
@VisibleForTesting
public static String parseNamedStatement(String sql, Map<String, List<Integer>> paramMap) {
Pattern pattern =
Pattern.compile(":([\\p{L}\\p{Nl}\\p{Nd}\\p{Pc}\\$\\-\\.@%&*#~!?^+=<>|]+)");
Matcher matcher = pattern.matcher(sql);

StringBuffer result = new StringBuffer();
int fieldIndex = 1;

while (matcher.find()) {
String parameterName = matcher.group(1);
checkArgument(
!parameterName.isEmpty(),
"Named parameters in SQL statement must not be empty.");
paramMap.computeIfAbsent(parameterName, n -> new ArrayList<>()).add(fieldIndex++);
matcher.appendReplacement(result, "?");
}
return parsedSql.toString();

matcher.appendTail(result);

return result.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.connectors.seatunnel.jdbc.internal.executor;

import org.junit.jupiter.api.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class FieldNamedPreparedStatementTest {

private static final String[] SPECIAL_FILEDNAMES =
new String[] {
"USER@TOKEN",
"字段%名称",
"field_name",
"field.name",
"field-name",
"$fieldName",
"field&key",
"field*value",
"field#1",
"field~test",
"field!data",
"field?question",
"field^caret",
"field+add",
"field=value",
"fieldmax",
"field|pipe"
};

@Test
public void testParseNamedStatementWithSpecialCharacters() {
String sql =
"INSERT INTO `nhp_emr_ws`.`cm_prescriptiondetails_cs` (`USER@TOKEN`, `字段%名称`, `field_name`, `field.name`, `field-name`, `$fieldName`, `field&key`, `field*value`, `field#1`, `field~test`, `field!data`, `field?question`, `field^caret`, `field+add`, `field=value`, `fieldmax`, `field|pipe`) VALUES (:USER@TOKEN, :字段%名称, :field_name, :field.name, :field-name, :$fieldName, :field&key, :field*value, :field#1, :field~test, :field!data, :field?question, :field^caret, :field+add, :field=value, :fieldmax, :field|pipe) ON DUPLICATE KEY UPDATE `USER@TOKEN`=VALUES(`USER@TOKEN`), `字段%名称`=VALUES(`字段%名称`), `field_name`=VALUES(`field_name`), `field.name`=VALUES(`field.name`), `field-name`=VALUES(`field-name`), `$fieldName`=VALUES(`$fieldName`), `field&key`=VALUES(`field&key`), `field*value`=VALUES(`field*value`), `field#1`=VALUES(`field#1`), `field~test`=VALUES(`field~test`), `field!data`=VALUES(`field!data`), `field?question`=VALUES(`field?question`), `field^caret`=VALUES(`field^caret`), `field+add`=VALUES(`field+add`), `field=value`=VALUES(`field=value`), `fieldmax`=VALUES(`fieldmax`), `field|pipe`=VALUES(`field|pipe`)";

String exceptPreparedstatement =
"INSERT INTO `nhp_emr_ws`.`cm_prescriptiondetails_cs` (`USER@TOKEN`, `字段%名称`, `field_name`, `field.name`, `field-name`, `$fieldName`, `field&key`, `field*value`, `field#1`, `field~test`, `field!data`, `field?question`, `field^caret`, `field+add`, `field=value`, `fieldmax`, `field|pipe`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE `USER@TOKEN`=VALUES(`USER@TOKEN`), `字段%名称`=VALUES(`字段%名称`), `field_name`=VALUES(`field_name`), `field.name`=VALUES(`field.name`), `field-name`=VALUES(`field-name`), `$fieldName`=VALUES(`$fieldName`), `field&key`=VALUES(`field&key`), `field*value`=VALUES(`field*value`), `field#1`=VALUES(`field#1`), `field~test`=VALUES(`field~test`), `field!data`=VALUES(`field!data`), `field?question`=VALUES(`field?question`), `field^caret`=VALUES(`field^caret`), `field+add`=VALUES(`field+add`), `field=value`=VALUES(`field=value`), `fieldmax`=VALUES(`fieldmax`), `field|pipe`=VALUES(`field|pipe`)";

Map<String, List<Integer>> paramMap = new HashMap<>();
String actualSQL = FieldNamedPreparedStatement.parseNamedStatement(sql, paramMap);
assertEquals(exceptPreparedstatement, actualSQL);
for (int i = 0; i < SPECIAL_FILEDNAMES.length; i++) {
assertTrue(paramMap.containsKey(SPECIAL_FILEDNAMES[i]));
assertEquals(i + 1, paramMap.get(SPECIAL_FILEDNAMES[i]).get(0));
}
}

@Test
public void testParseNamedStatement() {
String sql = "UPDATE table SET col1 = :param1, col2 = :param1 WHERE col3 = :param2";
Map<String, List<Integer>> paramMap = new HashMap<>();
String expectedSQL = "UPDATE table SET col1 = ?, col2 = ? WHERE col3 = ?";

String actualSQL = FieldNamedPreparedStatement.parseNamedStatement(sql, paramMap);

assertEquals(expectedSQL, actualSQL);
assertTrue(paramMap.containsKey("param1"));
assertTrue(paramMap.containsKey("param2"));
assertEquals(1, paramMap.get("param1").get(0).intValue());
assertEquals(2, paramMap.get("param1").get(1).intValue());
assertEquals(3, paramMap.get("param2").get(0).intValue());
}

@Test
public void testParseNamedStatementWithNoNamedParameters() {
String sql = "SELECT * FROM table";
Map<String, List<Integer>> paramMap = new HashMap<>();
String expectedSQL = "SELECT * FROM table";

String actualSQL = FieldNamedPreparedStatement.parseNamedStatement(sql, paramMap);

assertEquals(expectedSQL, actualSQL);
assertTrue(paramMap.isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public class JdbcMysqlIT extends AbstractJdbcIT {
private static final String CREATE_SQL =
"CREATE TABLE IF NOT EXISTS %s\n"
+ "(\n"
+ " `c_bit_1` bit(1) DEFAULT NULL,\n"
+ " `c-bit_1` bit(1) DEFAULT NULL,\n"
+ " `c_bit_8` bit(8) DEFAULT NULL,\n"
+ " `c_bit_16` bit(16) DEFAULT NULL,\n"
+ " `c_bit_32` bit(32) DEFAULT NULL,\n"
Expand Down Expand Up @@ -191,7 +191,7 @@ protected void checkResult(
String executeKey, TestContainer container, Container.ExecResult execResult) {
String[] fieldNames =
new String[] {
"c_bit_1",
"c-bit_1",
"c_bit_8",
"c_bit_16",
"c_bit_32",
Expand Down Expand Up @@ -249,7 +249,7 @@ String driverUrl() {
Pair<String[], List<SeaTunnelRow>> initTestData() {
String[] fieldNames =
new String[] {
"c_bit_1",
"c-bit_1",
"c_bit_8",
"c_bit_16",
"c_bit_32",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ sink {
user = "root"
password = "Abc!@#135_seatunnel"

query = """insert into sink (c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
query = """insert into sink (`c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ CREATE TABLE sink_table WITH (


INSERT INTO sink_table
SELECT c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
SELECT `c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ sink {
user = "root"
password = "Abc!@#135_seatunnel"
connection_check_timeout_sec = 100
query = """insert into sink (c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
query = """insert into sink (`c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ CREATE TABLE sink_table WITH (


CREATE TABLE temp1 AS
SELECT c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
SELECT `c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand All @@ -58,4 +58,4 @@ CREATE TABLE temp1 AS


INSERT INTO sink_table SELECT * FROM temp1;

Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ sink {
user = "root"
password = "Abc!@#135_seatunnel"
connection_check_timeout_sec = 100
query = """insert into sink (c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
query = """insert into sink (`c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public SeaTunnelRowType typeMapping(List<String> inputColumnsMapping) {
for (SelectItem selectItem : selectItems) {
if (selectItem.getExpression() instanceof AllColumns) {
for (int i = 0; i < inputRowType.getFieldNames().length; i++) {
fieldNames[idx] = inputRowType.getFieldName(i);
fieldNames[idx] = cleanEscape(inputRowType.getFieldName(i));
seaTunnelDataTypes[idx] = inputRowType.getFieldType(i);
if (inputColumnsMapping != null) {
inputColumnsMapping.set(idx, inputRowType.getFieldName(i));
Expand All @@ -194,16 +194,12 @@ public SeaTunnelRowType typeMapping(List<String> inputColumnsMapping) {
Expression expression = selectItem.getExpression();
if (selectItem.getAlias() != null) {
String aliasName = selectItem.getAlias().getName();
if (aliasName.startsWith(ESCAPE_IDENTIFIER)
&& aliasName.endsWith(ESCAPE_IDENTIFIER)) {
aliasName = aliasName.substring(1, aliasName.length() - 1);
}
fieldNames[idx] = aliasName;
fieldNames[idx] = cleanEscape(aliasName);
} else {
if (expression instanceof Column) {
fieldNames[idx] = ((Column) expression).getColumnName();
fieldNames[idx] = cleanEscape(((Column) expression).getColumnName());
} else {
fieldNames[idx] = expression.toString();
fieldNames[idx] = cleanEscape(expression.toString());
}
}

Expand All @@ -225,6 +221,13 @@ public SeaTunnelRowType typeMapping(List<String> inputColumnsMapping) {
fieldNames, seaTunnelDataTypes, lateralViews, inputColumnsMapping);
}

private static String cleanEscape(String columnName) {
if (columnName.startsWith(ESCAPE_IDENTIFIER) && columnName.endsWith(ESCAPE_IDENTIFIER)) {
columnName = columnName.substring(1, columnName.length() - 1);
}
return columnName;
}

@Override
public List<SeaTunnelRow> transformBySQL(SeaTunnelRow inputRow, SeaTunnelRowType outRowType) {
// ------Physical Query Plan Execution------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,13 @@ public void testEscapeIdentifier() {
ReadonlyConfig.fromMap(
Collections.singletonMap(
"query",
"select id, trim(`apply`) as `apply` from test where `apply` = 'a'"));
"select `id`, trim(`apply`) as `apply` from test where `apply` = 'a'"));
SQLTransform sqlTransform = new SQLTransform(config, table);
TableSchema tableSchema = sqlTransform.transformTableSchema();
List<SeaTunnelRow> result =
sqlTransform.transformRow(
new SeaTunnelRow(new Object[] {Integer.valueOf(1), String.valueOf("a")}));
Assertions.assertEquals("id", tableSchema.getFieldNames()[0]);
Assertions.assertEquals("apply", tableSchema.getFieldNames()[1]);
Assertions.assertEquals("a", result.get(0).getField(1));
result =
Expand Down

0 comments on commit c02d4fe

Please sign in to comment.