2 changes: 1 addition & 1 deletion flink-python/pyflink/fn_execution/table/aggregate_fast.pxd
@@ -59,7 +59,7 @@ cdef class SimpleAggsHandleFunction(SimpleAggsHandleFunctionBase):
     cdef size_t _get_value_indexes_length

 cdef class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase):
-    pass
+    cdef list _convert_to_row(self, data)

 cdef class RecordCounter:
     cdef bint record_count_is_zero(self, list acc)
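The .pxd change swaps the placeholder `pass` for a C-level declaration of the new helper; the matching definition follows in aggregate_fast.pyx below. As a quick illustration of the normalization that helper is declared to perform, here is a plain-Python sketch (the function name is illustrative, not the module's API):

from pyflink.common import Row

def convert_to_row(data):
    # Mirrors the new helper: Row -> its internal value list,
    # tuple -> list of its elements, any other object -> one-field list.
    if isinstance(data, Row):
        return data._values
    elif isinstance(data, tuple):
        return list(data)
    else:
        return [data]

print(convert_to_row(Row(1, 'Hi')))   # [1, 'Hi']
print(convert_to_row((1, 'Hi')))      # [1, 'Hi']
print(convert_to_row(42))             # [42]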
13 changes: 11 additions & 2 deletions flink-python/pyflink/fn_execution/table/aggregate_fast.pyx
@@ -24,6 +24,7 @@ from typing import List, Dict

 from apache_beam.coders import PickleCoder, Coder

+from pyflink.common import Row
 from pyflink.fn_execution.table.state_data_view import DataViewSpec, ListViewSpec, MapViewSpec, \
     PerKeyStateDataViewStore
 from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
@@ -379,12 +380,20 @@ cdef class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase):
         results = []
         for x in udf.emit_value(self._accumulators[0]):
             if is_retract:
-                result = join_row(current_key, x._values, InternalRowKind.DELETE)
+                result = join_row(current_key, self._convert_to_row(x), InternalRowKind.DELETE)
             else:
-                result = join_row(current_key, x._values, InternalRowKind.INSERT)
+                result = join_row(current_key, self._convert_to_row(x), InternalRowKind.INSERT)
             results.append(result)
         return results

+    cdef list _convert_to_row(self, data):
+        if isinstance(data, Row):
+            return data._values
+        elif isinstance(data, tuple):
+            return list(data)
+        else:
+            return [data]
+
 cdef class RecordCounter:
     """
     The RecordCounter is used to count the number of input records under the current key.
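With the fast emit path now routing every emitted element through `_convert_to_row`, a user-defined `emit_value` may yield a `Row`, a tuple, or a bare scalar. A rough, pure-Python sketch of the loop above; `fake_join_row` and the kind constants are stand-ins for the module's `join_row` and `InternalRowKind`, which are not shown in this excerpt:

from pyflink.common import Row

INSERT, DELETE = '+I', '-D'  # stand-ins for InternalRowKind.INSERT / DELETE

def fake_join_row(key, values, kind):
    # The real join_row builds an InternalRow; a tagged tuple is enough here.
    return (kind, list(key) + values)

def emit_value(current_key, emitted, is_retract):
    results = []
    for x in emitted:
        # Same normalization as the _convert_to_row helper in the diff above.
        values = x._values if isinstance(x, Row) else (list(x) if isinstance(x, tuple) else [x])
        results.append(fake_join_row(current_key, values, DELETE if is_retract else INSERT))
    return results

print(emit_value(['Hello'], [3, 2], is_retract=False))
# [('+I', ['Hello', 3]), ('+I', ['Hello', 2])]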
10 changes: 9 additions & 1 deletion flink-python/pyflink/fn_execution/table/aggregate_slow.py
@@ -355,13 +355,21 @@ def emit_value(self, current_key: List, is_retract: bool):
         udf = self._udfs[0]  # type: TableAggregateFunction
         results = udf.emit_value(self._accumulators[0])
         for x in results:
-            result = join_row(current_key, x._values)
+            result = join_row(current_key, self._convert_to_row(x))
             if is_retract:
                 result.set_row_kind(RowKind.DELETE)
             else:
                 result.set_row_kind(RowKind.INSERT)
             yield result

+    def _convert_to_row(self, data):
+        if isinstance(data, Row):
+            return data._values
+        elif isinstance(data, tuple):
+            return list(data)
+        else:
+            return [data]
+

 class RecordCounter(ABC):
     """
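The pure-Python path makes the same change but keeps its existing shape: it yields results lazily and tags the row kind with `set_row_kind` after construction instead of passing it to `join_row`. A small hedged sketch of that behaviour, assuming `Row` and `RowKind` are importable from `pyflink.common` as the slow path does (the key and values below are made up):

from pyflink.common import Row, RowKind

def emit_value(current_key, emitted, is_retract):
    for x in emitted:
        # Normalize as _convert_to_row does, then prepend the group key.
        values = x._values if isinstance(x, Row) else (list(x) if isinstance(x, tuple) else [x])
        result = Row(*current_key, *values)
        result.set_row_kind(RowKind.DELETE if is_retract else RowKind.INSERT)
        yield result

for row in emit_value(['Hello'], [3, 2], is_retract=True):
    print(row.get_row_kind(), row._values)
# RowKind.DELETE ['Hello', 3]
# RowKind.DELETE ['Hello', 2]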
9 changes: 4 additions & 5 deletions flink-python/pyflink/table/tests/test_row_based_operation.py
@@ -264,7 +264,7 @@ def test_flat_aggregate(self):
              (2, 'Hi', 'Hello')], ['a', 'b', 'c'])
         result = t.select(t.a, t.c) \
             .group_by(t.c) \
-            .flat_aggregate(mytop) \
+            .flat_aggregate(mytop.alias('a')) \
             .select(t.a) \
             .flat_aggregate(mytop.alias("b")) \
             .select("b") \
@@ -339,8 +339,8 @@ def get_result_type(self):
 class Top2(TableAggregateFunction):

     def emit_value(self, accumulator):
-        yield Row(accumulator[0])
-        yield Row(accumulator[1])
+        yield accumulator[0]
+        yield accumulator[1]

     def create_accumulator(self):
         return [None, None]
@@ -365,8 +365,7 @@ def get_accumulator_type(self):
         return DataTypes.ARRAY(DataTypes.BIGINT())

     def get_result_type(self):
-        return DataTypes.ROW(
-            [DataTypes.FIELD("a", DataTypes.BIGINT())])
+        return DataTypes.BIGINT()


 class ListViewConcatTableAggregateFunction(TableAggregateFunction):
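Taken together, the test change exercises the new behaviour end to end: `emit_value` yields bare BIGINT values and the result type is declared as a scalar rather than a one-field ROW. A hedged, self-contained sketch of how that looks from the Table API side; the environment setup and the `accumulate` body are illustrative boilerplate, and only the emit, result-type, and flat_aggregate parts mirror the diff above:

from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
from pyflink.table.udf import udtaf, TableAggregateFunction

class Top2(TableAggregateFunction):

    def emit_value(self, accumulator):
        # Scalars are now accepted directly; no need to wrap them in Row.
        yield accumulator[0]
        yield accumulator[1]

    def create_accumulator(self):
        return [None, None]

    def accumulate(self, accumulator, row):
        # Illustrative top-2 logic over the first field of the input row.
        if row[0] is not None:
            if accumulator[0] is None or row[0] > accumulator[0]:
                accumulator[1] = accumulator[0]
                accumulator[0] = row[0]
            elif accumulator[1] is None or row[0] > accumulator[1]:
                accumulator[1] = row[0]

    def get_accumulator_type(self):
        return DataTypes.ARRAY(DataTypes.BIGINT())

    def get_result_type(self):
        return DataTypes.BIGINT()

t_env = TableEnvironment.create(
    EnvironmentSettings.new_instance().in_streaming_mode().build())
t = t_env.from_elements(
    [(1, 'Hi', 'Hello'), (2, 'Hi2', 'Hello'), (2, 'Hi', 'Hello')],
    ['a', 'b', 'c'])

mytop = udtaf(Top2())
result = t.select(t.a, t.c) \
    .group_by(t.c) \
    .flat_aggregate(mytop.alias('a')) \
    .select(t.a)
# result.execute().print()  # would emit the top-2 'a' values per group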