GroupAggregateHarnessTest.scala
@@ -19,16 +19,20 @@
package org.apache.flink.table.planner.runtime.harness

import org.apache.flink.api.scala._
import org.apache.flink.table.api.{EnvironmentSettings, _}
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl
import org.apache.flink.table.api.config.ExecutionConfigOptions.{TABLE_EXEC_MINIBATCH_ALLOW_LATENCY, TABLE_EXEC_MINIBATCH_ENABLED, TABLE_EXEC_MINIBATCH_SIZE}
import org.apache.flink.table.api.config.OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY
import org.apache.flink.table.api.{EnvironmentSettings, _}
import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.runtime.utils.StreamingWithMiniBatchTestBase.{MiniBatchMode, MiniBatchOff, MiniBatchOn}
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils.CountNullNonNull
import org.apache.flink.table.runtime.typeutils.RowDataSerializer
import org.apache.flink.table.runtime.util.RowDataHarnessAssertor
import org.apache.flink.table.runtime.util.StreamRecordUtils.binaryRecord
import org.apache.flink.table.types.logical.LogicalType
import org.apache.flink.types.Row
import org.apache.flink.types.RowKind._

@@ -46,7 +50,7 @@ import scala.collection.mutable

@RunWith(classOf[Parameterized])
class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)
  extends HarnessTestBase(mode) {

@Before
override def before(): Unit = {
@@ -100,7 +104,7 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)
testHarness.setStateTtlProcessingTime(1)

// insertion
testHarness.processElement(binaryRecord(INSERT,"aaa", 1L: JLong))
testHarness.processElement(binaryRecord(INSERT, "aaa", 1L: JLong))
expectedOutput.add(binaryRecord(INSERT, "aaa", 1L: JLong))

// insertion
@@ -144,7 +148,7 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)
expectedOutput.add(binaryRecord(INSERT, "eee", 6L: JLong))

// retract
testHarness.processElement(binaryRecord(INSERT,"aaa", 7L: JLong))
testHarness.processElement(binaryRecord(INSERT, "aaa", 7L: JLong))
expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 9L: JLong))
expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
testHarness.processElement(binaryRecord(INSERT, "bbb", 3L: JLong))
@@ -160,28 +164,8 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)

@Test
def testAggregationWithDistinct(): Unit = {
val data = new mutable.MutableList[(String, String, Long)]
val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
tEnv.createTemporaryView("T", t)
tEnv.createTemporarySystemFunction("CntNullNonNull", new CountNullNonNull)

val sql =
"""
|SELECT a, COUNT(DISTINCT b), CntNullNonNull(DISTINCT b), COUNT(*), SUM(c)
|FROM T
|GROUP BY a
""".stripMargin
val t1 = tEnv.sqlQuery(sql)

tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(2))
val testHarness = createHarnessTester(t1.toRetractStream[Row], "GroupAggregate")
val assertor = new RowDataHarnessAssertor(
Array(
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType))
val (testHarness, outputTypes) = createAggregationWithDistinct
val assertor = new RowDataHarnessAssertor(outputTypes)

testHarness.open()

@@ -191,7 +175,7 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)
testHarness.setStateTtlProcessingTime(1)

// insertion
testHarness.processElement(binaryRecord(INSERT,"aaa", "a1", 1L: JLong))
testHarness.processElement(binaryRecord(INSERT, "aaa", "a1", 1L: JLong))
expectedOutput.add(binaryRecord(INSERT, "aaa", 1L: JLong, "1|0", 1L: JLong, 1L: JLong))

// insertion
@@ -240,6 +224,41 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)
testHarness.close()
}

private def createAggregationWithDistinct()
: (KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData], Array[LogicalType]) = {
val data = new mutable.MutableList[(String, String, Long)]
val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c)
tEnv.createTemporaryView("T", t)
tEnv.createTemporarySystemFunction("CntNullNonNull", new CountNullNonNull)

val sql =
"""
|SELECT a, COUNT(DISTINCT b), CntNullNonNull(DISTINCT b), COUNT(*), SUM(c)
|FROM T
|GROUP BY a
""".stripMargin
val t1 = tEnv.sqlQuery(sql)

tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(2))
val testHarness = createHarnessTester(t1.toRetractStream[Row], "GroupAggregate")
val outputTypes = Array(
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType)

(testHarness, outputTypes)
}

@Test
def testCloseWithoutOpen(): Unit = {
val (testHarness, outputType) = createAggregationWithDistinct
testHarness.setup(new RowDataSerializer(outputType: _*))
    // simulate a failover after a failed task open (e.g., stuck during initialization);
    // no exception is expected
testHarness.close()
}
}

object GroupAggregateHarnessTest {
  // ... (unchanged)
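
The three new testCloseWithoutOpen cases all pin down the same contract: an operator's close() must be safe to call even when open() never ran, because a failover can tear down a task that is still initializing. A way to share that check across harness tests would be a small helper along these lines (a sketch, not part of this diff; it only reuses names that appear above):

    // Hypothetical helper, not part of this PR: runs the close-without-open
    // check for any keyed harness produced by a setup function such as
    // createAggregationWithDistinct.
    private def assertCloseWithoutOpen(
        setup: () => (KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData],
          Array[LogicalType])): Unit = {
      val (testHarness, outputTypes) = setup()
      // setup() wires the operator and output serializer but deliberately
      // skips open(), simulating a task that failed during initialization
      testHarness.setup(new RowDataSerializer(outputTypes: _*))
      // close() must not throw even though open() never ran
      testHarness.close()
    }

A test would then reduce to assertCloseWithoutOpen(() => createAggregationWithDistinct).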
HarnessTestBase.scala
@@ -28,9 +28,9 @@ import org.apache.flink.runtime.state.StateBackend
import org.apache.flink.runtime.state.memory.MemoryStateBackend
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.streaming.api.scala.DataStream
import org.apache.flink.streaming.api.transformations.OneInputTransformation
import org.apache.flink.streaming.api.transformations.{OneInputTransformation, PartitionTransformation}
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness
import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness}
import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.JLong
import org.apache.flink.table.planner.runtime.utils.StreamingTestBase
@@ -86,6 +86,19 @@ class HarnessTestBase(mode: StateBackendMode) extends StreamingTestBase {
.asInstanceOf[KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData]]
}

def createHarnessTesterForNoState(
ds: DataStream[_],
prefixOperatorName: String)
: OneInputStreamOperatorTestHarness[RowData, RowData] = {
val transformation = extractExpectedTransformation(
ds.javaStream.getTransformation,
prefixOperatorName)
val processOperator = transformation.getOperator
.asInstanceOf[OneInputStreamOperator[Any, Any]]
new OneInputStreamOperatorTestHarness(processOperator)
.asInstanceOf[OneInputStreamOperatorTestHarness[RowData, RowData]]
}

private def extractExpectedTransformation(
t: Transformation[_],
prefixOperatorName: String): OneInputTransformation[_, _] = {
@@ -96,6 +109,8 @@ class HarnessTestBase(mode: StateBackendMode) extends StreamingTestBase {
} else {
extractExpectedTransformation(one.getInputs.get(0), prefixOperatorName)
}
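      // partitioning steps (e.g. keyBy) carry no operator of their own, so look through them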
case p: PartitionTransformation[_] =>
extractExpectedTransformation(p.getInputs.get(0), prefixOperatorName)
case _ => throw new Exception(
s"Can not find the expected $prefixOperatorName transformation")
}
  // ... (unchanged)
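
The new createHarnessTesterForNoState mirrors createHarnessTester but wraps the extracted operator in a plain OneInputStreamOperatorTestHarness instead of a keyed one, so operators without keyed state can be driven directly. A minimal usage sketch (the query and the "Calc" operator-name prefix are illustrative assumptions, not taken from this diff):

    // Illustrative only: locate a non-keyed operator by name prefix and drive it.
    val resultTable = tEnv.sqlQuery("SELECT a, c FROM T WHERE c > 0") // assumed view T
    val testHarness = createHarnessTesterForNoState(
      resultTable.toAppendStream[Row], "Calc") // "Calc" prefix is an assumption
    testHarness.open()
    // ... processElement / getOutput assertions ...
    testHarness.close()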
OverAggregateHarnessTest.scala
@@ -20,14 +20,17 @@ package org.apache.flink.table.planner.runtime.harness

import org.apache.flink.api.common.time.Time
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness
import org.apache.flink.table.api._
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl
import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.runtime.typeutils.RowDataSerializer
import org.apache.flink.table.runtime.util.RowDataHarnessAssertor
import org.apache.flink.table.runtime.util.StreamRecordUtils.{binaryrow, row}
import org.apache.flink.table.types.logical.LogicalType
import org.apache.flink.types.Row

import org.junit.runner.RunWith
@@ -52,34 +55,8 @@ class OverAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode)

@Test
def testProcTimeBoundedRowsOver(): Unit = {

val data = new mutable.MutableList[(Long, String, Long)]
val t = env.fromCollection(data).toTable(tEnv, 'currtime, 'b, 'c, 'proctime.proctime)
tEnv.registerTable("T", t)

val sql =
"""
|SELECT currtime, b, c,
| min(c) OVER
| (PARTITION BY b ORDER BY proctime ROWS BETWEEN 1 PRECEDING AND CURRENT ROW),
| max(c) OVER
| (PARTITION BY b ORDER BY proctime ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
|FROM T
""".stripMargin
val t1 = tEnv.sqlQuery(sql)

tEnv.getConfig.setIdleStateRetentionTime(Time.seconds(2), Time.seconds(4))
val testHarness = createHarnessTester(t1.toAppendStream[Row], "OverAggregate")
val assertor = new RowDataHarnessAssertor(
Array(
DataTypes.BIGINT().getLogicalType,
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType))

val (testHarness, outputType) = createProcTimeBoundedRowsOver
val assertor = new RowDataHarnessAssertor(outputType)
testHarness.open()

// register cleanup timer with 3001
@@ -161,6 +138,36 @@ class OverAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode)
testHarness.close()
}

private def createProcTimeBoundedRowsOver()
: (KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData], Array[LogicalType]) = {
val data = new mutable.MutableList[(Long, String, Long)]
val t = env.fromCollection(data).toTable(tEnv, 'currtime, 'b, 'c, 'proctime.proctime)
tEnv.registerTable("T", t)

val sql =
"""
|SELECT currtime, b, c,
| min(c) OVER
| (PARTITION BY b ORDER BY proctime ROWS BETWEEN 1 PRECEDING AND CURRENT ROW),
| max(c) OVER
| (PARTITION BY b ORDER BY proctime ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
|FROM T
""".stripMargin
val t1 = tEnv.sqlQuery(sql)

tEnv.getConfig.setIdleStateRetentionTime(Time.seconds(2), Time.seconds(4))
val testHarness = createHarnessTester(t1.toAppendStream[Row], "OverAggregate")
val outputType = Array(
DataTypes.BIGINT().getLogicalType,
DataTypes.STRING().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType,
DataTypes.BIGINT().getLogicalType)
(testHarness, outputType)
}

/**
* NOTE: all elements at the same proc timestamp have the same value per key
*/
@@ -940,4 +947,12 @@ class OverAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode)
assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)
testHarness.close()
}

@Test
def testCloseWithoutOpen(): Unit = {
val (testHarness, outputType) = createProcTimeBoundedRowsOver
testHarness.setup(new RowDataSerializer(outputType: _*))
    // simulate a failover after a failed task open; no exception is expected
    testHarness.close()
}
}
TableAggregateHarnessTest.scala
@@ -18,23 +18,27 @@

package org.apache.flink.table.planner.runtime.harness

import java.lang.{Integer => JInt}
import java.util.concurrent.ConcurrentLinkedQueue
import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl
import org.apache.flink.table.api.EnvironmentSettings
import org.apache.flink.table.api.{EnvironmentSettings, _}
import org.apache.flink.table.data.RowData
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.planner.utils.{Top3WithMapView, Top3WithRetractInput}
import org.apache.flink.table.runtime.typeutils.RowDataSerializer
import org.apache.flink.table.runtime.util.RowDataHarnessAssertor
import org.apache.flink.table.runtime.util.StreamRecordUtils.{deleteRecord, insertRecord}
import org.apache.flink.table.types.logical.LogicalType
import org.apache.flink.types.Row

import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{Before, Test}

import java.lang.{Integer => JInt}
import java.time.Duration
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable

@@ -117,22 +121,8 @@ class TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode)

@Test
def testTableAggregateWithRetractInput(): Unit = {
val top3 = new Top3WithRetractInput
tEnv.registerFunction("top3", top3)
val source = env.fromCollection(data).toTable(tEnv, 'a, 'b)
val resultTable = source
.groupBy('a)
.select('b.sum as 'b)
.flatAggregate(top3('b) as ('b1, 'b2))
.select('b1, 'b2)

tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(2))
val testHarness = createHarnessTester(
resultTable.toRetractStream[Row], "GroupTableAggregate")
val assertor = new RowDataHarnessAssertor(
Array(
DataTypes.INT().getLogicalType,
DataTypes.INT().getLogicalType))
val (testHarness, outputTypes) = createTableAggregateWithRetract
val assertor = new RowDataHarnessAssertor(outputTypes)

testHarness.open()
val expectedOutput = new ConcurrentLinkedQueue[Object]()
@@ -170,4 +160,32 @@ class TableAggregateHarnessTest(mode: StateBackendMode) extends HarnessTestBase(mode)
assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)
testHarness.close()
}

private def createTableAggregateWithRetract()
: (KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData], Array[LogicalType]) = {
val top3 = new Top3WithRetractInput
tEnv.registerFunction("top3", top3)
val source = env.fromCollection(data).toTable(tEnv, 'a, 'b)
val resultTable = source
.groupBy('a)
.select('b.sum as 'b)
.flatAggregate(top3('b) as('b1, 'b2))
.select('b1, 'b2)

tEnv.getConfig.setIdleStateRetention(Duration.ofSeconds(2))
val testHarness = createHarnessTester(
resultTable.toRetractStream[Row], "GroupTableAggregate")
val outputTypes = Array(
DataTypes.INT().getLogicalType,
DataTypes.INT().getLogicalType)
(testHarness, outputTypes)
}

@Test
def testCloseWithoutOpen(): Unit = {
val (testHarness, outputTypes) = createTableAggregateWithRetract
testHarness.setup(new RowDataSerializer(outputTypes: _*))
    // simulate a failover after a failed task open; no exception is expected
testHarness.close()
}
}