[Bug][Transform][Spark] Remove unnecessary row conversion #4335

Merged · 7 commits · Mar 14, 2023
@@ -26,12 +26,10 @@
 import org.apache.seatunnel.core.starter.exception.TaskExecuteException;
 import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
 import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelTransformPluginDiscovery;
-import org.apache.seatunnel.translation.spark.serialization.InternalRowConverter;
 import org.apache.seatunnel.translation.spark.utils.TypeConverterUtils;
 
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
 import org.apache.spark.sql.types.StructType;

@@ -125,21 +123,14 @@ private Dataset<Row> sparkTransform(SeaTunnelTransform transform, Dataset<Row> s
         SeaTunnelRow seaTunnelRow;
         List<Row> outputRows = new ArrayList<>();
         Iterator<Row> rowIterator = stream.toLocalIterator();
-        InternalRowConverter inputRowConverter = new InternalRowConverter(seaTunnelDataType);
-        InternalRowConverter outputRowConverter =
-                new InternalRowConverter(transform.getProducedType());
         while (rowIterator.hasNext()) {
             Row row = rowIterator.next();
-            seaTunnelRow = inputRowConverter.reconvert(InternalRow.apply(row.toSeq()));
+            seaTunnelRow = new SeaTunnelRow(((GenericRowWithSchema) row).values());
             seaTunnelRow = (SeaTunnelRow) transform.map(seaTunnelRow);
             if (seaTunnelRow == null) {
                 continue;
             }
-            InternalRow internalRow = outputRowConverter.convert(seaTunnelRow);
-
-            Object[] fields = outputRowConverter.convertDateTime(internalRow, structType);
-
-            outputRows.add(new GenericRowWithSchema(fields, structType));
+            outputRows.add(new GenericRowWithSchema(seaTunnelRow.getFields(), structType));
         }
         return sparkRuntimeEnvironment.getSparkSession().createDataFrame(outputRows, structType);
     }
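Taken together, the two hunks above collapse the per-row path to the loop below. This is only a sketch of the resulting method body, stitched from the added lines; `transform`, `stream`, `structType`, and `sparkRuntimeEnvironment` are assumed to be the parameters and fields already in scope in the surrounding executor:

List<Row> outputRows = new ArrayList<>();
Iterator<Row> rowIterator = stream.toLocalIterator();
while (rowIterator.hasNext()) {
    Row row = rowIterator.next();
    // GenericRowWithSchema already holds the field values as Object[],
    // so a SeaTunnelRow can wrap them directly, with no InternalRow detour.
    SeaTunnelRow seaTunnelRow = new SeaTunnelRow(((GenericRowWithSchema) row).values());
    seaTunnelRow = (SeaTunnelRow) transform.map(seaTunnelRow);
    if (seaTunnelRow == null) {
        continue; // the transform dropped this row
    }
    // SeaTunnelRow#getFields() hands the (possibly rewritten) values back
    // to Spark as a plain Row with the transform's output schema.
    outputRows.add(new GenericRowWithSchema(seaTunnelRow.getFields(), structType));
}
return sparkRuntimeEnvironment.getSparkSession().createDataFrame(outputRows, structType);

The effect is that rows move between Spark and the transform as plain Java objects: both passes through InternalRowConverter and the extra convertDateTime fix-up are gone. The near-identical hunks that follow apply the same change to a second copy of this executor.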
@@ -26,12 +26,10 @@
 import org.apache.seatunnel.core.starter.exception.TaskExecuteException;
 import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
 import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelTransformPluginDiscovery;
-import org.apache.seatunnel.translation.spark.serialization.InternalRowConverter;
 import org.apache.seatunnel.translation.spark.utils.TypeConverterUtils;
 
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
 import org.apache.spark.sql.types.StructType;

@@ -125,22 +123,16 @@ private Dataset<Row> sparkTransform(SeaTunnelTransform transform, Dataset<Row> s
         SeaTunnelRow seaTunnelRow;
         List<Row> outputRows = new ArrayList<>();
         Iterator<Row> rowIterator = stream.toLocalIterator();
-        InternalRowConverter inputRowConverter = new InternalRowConverter(seaTunnelDataType);
-        InternalRowConverter outputRowConverter =
-                new InternalRowConverter(transform.getProducedType());
         while (rowIterator.hasNext()) {
             Row row = rowIterator.next();
-            seaTunnelRow = inputRowConverter.reconvert(InternalRow.apply(row.toSeq()));
+            seaTunnelRow = new SeaTunnelRow(((GenericRowWithSchema) row).values());
             seaTunnelRow = (SeaTunnelRow) transform.map(seaTunnelRow);
             if (seaTunnelRow == null) {
                 continue;
             }
-            InternalRow internalRow = outputRowConverter.convert(seaTunnelRow);
-
-            Object[] fields = outputRowConverter.convertDateTime(internalRow, structType);
-
-            outputRows.add(new GenericRowWithSchema(fields, structType));
+            outputRows.add(new GenericRowWithSchema(seaTunnelRow.getFields(), structType));
         }
+
         return sparkRuntimeEnvironment.getSparkSession().createDataFrame(outputRows, structType);
     }
 }
@@ -31,6 +31,10 @@ source {
id = "int"
name = "string"
age = "int"
c_map = "map<string, string>"
c_array = "array<int>"
c_timestamp = "timestamp"
c_date = "date"
}
}
}
@@ -41,7 +45,8 @@ transform {
source_table_name = "fake"
result_table_name = "fake1"
# the query table name must same as field 'source_table_name'
query = "select id, regexp_replace(name, '.+', 'b') as name, age+1 as age, pi() as pi from fake"
query = """select id, regexp_replace(name, '.+', 'b') as name, age+1 as age, pi() as pi,
c_map, c_array, c_timestamp, c_date from fake"""
}
# The SQL transform support base function and criteria operation
# But the complex SQL unsupported yet, include: multi source table/rows JOIN and AGGREGATE operation and the like
@@ -105,6 +110,24 @@ sink {
             rule_type = NOT_NULL
           }
         ]
+      },
+      {
+        field_name = c_timestamp
+        field_type = timestamp
+        field_value = [
+          {
+            rule_type = NOT_NULL
+          }
+        ]
+      },
+      {
+        field_name = c_date
+        field_type = date
+        field_value = [
+          {
+            rule_type = NOT_NULL
+          }
+        ]
       }
     ]
   }
@@ -73,7 +73,8 @@ transform {

   # you can also use other transform plugins, such as sql
   sql {
-    sql = "select c_map,c_array,c_string,c_boolean,c_tinyint,c_smallint,c_int,c_bigint,c_float,c_double,c_null,c_bytes,c_date,c_timestamp from fake"
+    source_table_name = "fake"
+    query = "select c_map,c_array,c_string,c_boolean,c_tinyint,c_smallint,c_int,c_bigint,c_float,c_double,c_null,c_bytes,c_date,c_timestamp from fake"
     result_table_name = "sql"
   }

@@ -41,10 +41,7 @@
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.types.UTF8String;
 
 import java.io.IOException;
@@ -244,23 +241,4 @@ private static Object reconvertArray(ArrayData arrayData, ArrayType<?, ?> arrayT
         }
         return newArray;
     }
-
-    public Object[] convertDateTime(InternalRow internalRow, StructType structType) {
-        Object[] fields =
-                Arrays.stream(((SpecificInternalRow) internalRow).values())
-                        .map(MutableValue::boxed)
-                        .toArray();
-        int len = structType.fields().length;
-        for (int i = 0; i < len; i++) {
-            DataType dataType = structType.fields()[i].dataType();
-            Object field = fields[i];
-            if (dataType == DataTypes.TimestampType && field instanceof Long) {
-                fields[i] = Timestamp.from(InstantConverterUtils.ofEpochMicro((long) field));
-            }
-            if (dataType == DataTypes.DateType && field instanceof Integer) {
-                fields[i] = Date.valueOf(LocalDate.ofEpochDay((int) field));
-            }
-        }
-        return fields;
-    }
 }
@@ -41,10 +41,7 @@
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.types.UTF8String;
 
 import java.io.IOException;
@@ -244,23 +241,4 @@ private static Object reconvertArray(ArrayData arrayData, ArrayType<?, ?> arrayT
         }
         return newArray;
     }
-
-    public Object[] convertDateTime(InternalRow internalRow, StructType structType) {
-        Object[] fields =
-                Arrays.stream(((SpecificInternalRow) internalRow).values())
-                        .map(MutableValue::boxed)
-                        .toArray();
-        int len = structType.fields().length;
-        for (int i = 0; i < len; i++) {
-            DataType dataType = structType.fields()[i].dataType();
-            Object field = fields[i];
-            if (dataType == DataTypes.TimestampType && field instanceof Long) {
-                fields[i] = Timestamp.from(InstantConverterUtils.ofEpochMicro((long) field));
-            }
-            if (dataType == DataTypes.DateType && field instanceof Integer) {
-                fields[i] = Date.valueOf(LocalDate.ofEpochDay((int) field));
-            }
-        }
-        return fields;
-    }
 }
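For reference, the deleted convertDateTime helper only existed to box Spark's internal datetime encodings (TimestampType is stored as microseconds since the epoch, DateType as days since the epoch) back into java.sql values before rebuilding a Row. A rough JDK-only equivalent of that per-field conversion is sketched below; the class and method names are hypothetical, and java.time arithmetic stands in for the removed InstantConverterUtils call. With fields now taken directly from SeaTunnelRow, nothing calls this step anymore:

import java.sql.Date;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.temporal.ChronoUnit;

public class DateTimeBoxingSketch {
    // Hypothetical stand-in for the removed convertDateTime per-field logic:
    // epoch-microsecond longs become java.sql.Timestamp, epoch-day ints
    // become java.sql.Date, everything else passes through untouched.
    static Object boxSparkDateTime(Object field, boolean isTimestampColumn, boolean isDateColumn) {
        if (isTimestampColumn && field instanceof Long) {
            return Timestamp.from(Instant.EPOCH.plus((long) field, ChronoUnit.MICROS));
        }
        if (isDateColumn && field instanceof Integer) {
            return Date.valueOf(LocalDate.ofEpochDay((int) field));
        }
        return field;
    }
}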