Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -126,6 +127,7 @@ public final class TestValuesTableFactory

private static final AtomicInteger idCounter = new AtomicInteger(0);
private static final Map<String, Collection<Row>> registeredData = new HashMap<>();
private static final Map<String, Collection<RowData>> registeredRowData = new HashMap<>();

/**
* Register the given data into the data factory context and return the data id. The data id can
Expand All @@ -145,6 +147,24 @@ public static String registerData(Seq<Row> data) {
return registerData(JavaScalaConversionUtil.toJava(data));
}

/**
 * Registers the given internal {@link RowData} collection and returns a generated data id. The
 * id can be used in data connector DDL to reference the registered rows.
 */
public static String registerRowData(Collection<RowData> data) {
    final String id = Integer.toString(idCounter.incrementAndGet());
    registeredRowData.put(id, data);
    return id;
}

/**
 * Registers the given internal {@link RowData} rows (Scala {@code Seq} variant) and returns the
 * data id used to reference them in data connector DDL.
 */
public static String registerRowData(Seq<RowData> data) {
    final Collection<RowData> javaData = JavaScalaConversionUtil.toJava(data);
    return registerRowData(javaData);
}

/**
* Returns received raw results of the registered table sink. The raw results are encoded with
* {@link RowKind}.
Expand All @@ -171,6 +191,7 @@ public static List<Watermark> getWatermarkOutput(String tableName) {
/**
 * Removes all registered data (both external {@code Row} and internal {@code RowData} entries)
 * and clears any collected sink results.
 */
public static void clearAllData() {
    registeredData.clear();
    registeredRowData.clear();
    TestValuesRuntimeFunctions.clearResults();
}

Expand Down Expand Up @@ -263,6 +284,14 @@ private static RowKind parseRowKind(String rowKindShortString) {
private static final ConfigOption<Boolean> ENABLE_WATERMARK_PUSH_DOWN =
ConfigOptions.key("enable-watermark-push-down").booleanType().defaultValue(false);

// When true, 'data-id' refers to rows registered via registerRowData (internal RowData) and
// the factory returns a TestValuesScanTableSourceWithInternalData that emits them directly,
// bypassing the Row-based conversion path.
private static final ConfigOption<Boolean> INTERNAL_DATA =
ConfigOptions.key("register-internal-data")
.booleanType()
.defaultValue(false)
.withDescription(
"The registered data is internal type data, "
+ "which can be collected by the source directly.");

private static final ConfigOption<Map<String, String>> READABLE_METADATA =
ConfigOptions.key("readable-metadata")
.mapType()
Expand Down Expand Up @@ -325,6 +354,7 @@ public DynamicTableSource createDynamicTableSource(Context context) {
boolean enableWatermarkPushDown = helper.getOptions().get(ENABLE_WATERMARK_PUSH_DOWN);
boolean failingSource = helper.getOptions().get(FAILING_SOURCE);
int numElementToSkip = helper.getOptions().get(SOURCE_NUM_ELEMENT_TO_SKIP);
boolean internalData = helper.getOptions().get(INTERNAL_DATA);

Optional<List<String>> filterableFields =
helper.getOptions().getOptional(FILTERABLE_FIELDS);
Expand All @@ -336,6 +366,10 @@ public DynamicTableSource createDynamicTableSource(Context context) {
helper.getOptions().get(READABLE_METADATA), context.getClassLoader());

if (sourceClass.equals("DEFAULT")) {
if (internalData) {
return new TestValuesScanTableSourceWithInternalData(dataId, isBounded);
}

Collection<Row> data = registeredData.getOrDefault(dataId, Collections.emptyList());
List<Map<String, String>> partitions =
parsePartitionList(helper.getOptions().get(PARTITION_LIST));
Expand Down Expand Up @@ -505,7 +539,8 @@ public Set<ConfigOption<?>> optionalOptions() {
WRITABLE_METADATA,
ENABLE_WATERMARK_PUSH_DOWN,
SINK_DROP_LATE_EVENT,
SOURCE_NUM_ELEMENT_TO_SKIP));
SOURCE_NUM_ELEMENT_TO_SKIP,
INTERNAL_DATA));
}

private static int validateAndExtractRowtimeIndex(
Expand Down Expand Up @@ -1178,6 +1213,38 @@ public String asSummaryString() {
}
}

/**
 * A values {@link ScanTableSource} that emits the {@link RowData} rows registered under a data
 * id directly, without going through any external Row conversion.
 */
private static class TestValuesScanTableSourceWithInternalData implements ScanTableSource {

    private final String dataId;
    private final boolean bounded;

    public TestValuesScanTableSourceWithInternalData(String dataId, boolean bounded) {
        this.dataId = dataId;
        this.bounded = bounded;
    }

    @Override
    public ChangelogMode getChangelogMode() {
        // The registered internal data is append-only.
        return ChangelogMode.insertOnly();
    }

    @Override
    public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderContext) {
        return SourceFunctionProvider.of(new FromRowDataSourceFunction(dataId), bounded);
    }

    @Override
    public DynamicTableSource copy() {
        return new TestValuesScanTableSourceWithInternalData(dataId, bounded);
    }

    @Override
    public String asSummaryString() {
        return "TestValuesWithInternalData";
    }
}

// --------------------------------------------------------------------------------------------
// Table sinks
// --------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1393,4 +1460,33 @@ public String asSummaryString() {
return "TestSinkContextTableSink";
}
}

/**
 * A {@link SourceFunction} that emits the statically registered {@link RowData} rows for a data
 * id, handing them to the collector without any serialization step.
 */
private static class FromRowDataSourceFunction implements SourceFunction<RowData> {

    private final String dataId;
    // Cleared by cancel(); volatile so the emitting thread observes the change promptly.
    private volatile boolean isRunning = true;

    public FromRowDataSourceFunction(String dataId) {
        this.dataId = dataId;
    }

    @Override
    public void run(SourceContext<RowData> ctx) throws Exception {
        final Collection<RowData> rows =
                registeredRowData.getOrDefault(dataId, Collections.emptyList());
        for (RowData row : rows) {
            if (!isRunning) {
                return;
            }
            ctx.collect(row);
        }
    }

    @Override
    public void cancel() {
        isRunning = false;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,21 @@ import org.apache.flink.api.scala.typeutils.Types
import org.apache.flink.table.api._
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.api.internal.TableEnvironmentInternal
import org.apache.flink.table.data.{GenericRowData, RowData}
import org.apache.flink.table.data.{GenericRowData, MapData, RowData}
import org.apache.flink.table.planner.factories.TestValuesTableFactory
import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
import org.apache.flink.table.planner.runtime.utils._
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.flink.table.runtime.typeutils.MapDataSerializerTest.CustomMapData
import org.apache.flink.table.types.logical.{BigIntType, IntType, VarCharType}
import org.apache.flink.table.utils.LegacyRowResource
import org.apache.flink.types.Row
import org.apache.flink.util.CollectionUtil

import java.util
import org.junit.Assert._
import org.junit._

import scala.collection.JavaConversions._
import scala.collection.Seq

class CalcITCase extends StreamingTestBase {
Expand Down Expand Up @@ -294,6 +297,47 @@ class CalcITCase extends StreamingTestBase {
)
}

@Test
def testSourceWithCustomInternalData(): Unit = {

  // Builds a single-entry CustomMapData, i.e. a MapData implementation unknown to the planner.
  def singleEntryMap(k: Long, v: Long): MapData = {
    val entries = new util.HashMap[Long, Long]()
    entries.put(k, v)
    new CustomMapData(entries)
  }

  val firstRow: GenericRowData = new GenericRowData(2)
  firstRow.setField(0, 1L)
  firstRow.setField(1, singleEntryMap(1L, 2L))
  val secondRow: GenericRowData = new GenericRowData(2)
  secondRow.setField(0, 2L)
  secondRow.setField(1, singleEntryMap(4L, 5L))

  val myTableDataId = TestValuesTableFactory.registerRowData(List(firstRow, secondRow))

  val ddl =
    s"""
       |CREATE TABLE CustomTable (
       |  a bigint,
       |  b map<bigint, bigint>
       |) WITH (
       |  'connector' = 'values',
       |  'data-id' = '$myTableDataId',
       |  'register-internal-data' = 'true',
       |  'bounded' = 'true'
       |)
       """.stripMargin

  // Disable object reuse so the serializer's copy path (including custom MapData) is exercised.
  env.getConfig.disableObjectReuse()
  tEnv.executeSql(ddl)
  val result = tEnv.executeSql("select a, b from CustomTable")

  val expected = List("1,{1=2}", "2,{4=5}")
  val actual = CollectionUtil.iteratorToList(result.collect()).map(_.toString)
  assertEquals(expected.sorted, actual.sorted)
}

@Test
def testSimpleProject(): Unit = {
val myTableDataId = TestValuesTableFactory.registerData(TestData.smallData3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.table.data.ArrayData;
import org.apache.flink.table.data.GenericMapData;
import org.apache.flink.table.data.MapData;
import org.apache.flink.table.data.binary.BinaryArrayData;
import org.apache.flink.table.data.binary.BinaryMapData;
Expand Down Expand Up @@ -104,10 +103,10 @@ public MapData createInstance() {
*/
@Override
public MapData copy(MapData from) {
if (from instanceof GenericMapData) {
return toBinaryMap(from);
} else {
if (from instanceof BinaryMapData) {
return ((BinaryMapData) from).copy();
} else {
return toBinaryMap(from);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import org.apache.flink.api.common.typeutils.SerializerTestBase;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.data.ArrayData;
import org.apache.flink.table.data.GenericArrayData;
import org.apache.flink.table.data.GenericMapData;
import org.apache.flink.table.data.MapData;
import org.apache.flink.table.data.StringData;
Expand All @@ -31,6 +33,7 @@

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.apache.flink.table.data.util.MapDataUtil.convertToJavaMap;

Expand Down Expand Up @@ -83,6 +86,7 @@ protected MapData[] getTestData() {
first.put(1, StringData.fromString(""));
return new MapData[] {
new GenericMapData(first),
new CustomMapData(first),
BinaryMapData.valueOf(
createArray(1, 2), ArrayDataSerializerTest.createArray("11", "haa")),
BinaryMapData.valueOf(
Expand All @@ -104,4 +108,51 @@ private static BinaryArrayData createArray(int... vs) {
writer.complete();
return array;
}

/** A minimal {@link MapData} implementation backed by a plain {@link Map}, used for testing. */
public static class CustomMapData implements MapData {

    private final Map<?, ?> map;

    public CustomMapData(Map<?, ?> map) {
        this.map = map;
    }

    /** Returns the value stored under {@code key}, or {@code null} if absent. */
    public Object get(Object key) {
        return map.get(key);
    }

    @Override
    public int size() {
        return map.size();
    }

    @Override
    public ArrayData keyArray() {
        return new GenericArrayData(map.keySet().toArray());
    }

    @Override
    public ArrayData valueArray() {
        return new GenericArrayData(map.values().toArray());
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        return o instanceof CustomMapData && map.equals(((CustomMapData) o).map);
    }

    @Override
    public int hashCode() {
        return Objects.hash(map);
    }
}
}