Skip to content

Commit

Permalink
Remove support for int argument types in the ordered list state APIs.
Browse files Browse the repository at this point in the history
  • Loading branch information
shunping committed Oct 2, 2024
1 parent 5524337 commit 0dbdd41
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 69 deletions.
35 changes: 15 additions & 20 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,8 @@ def __str__(self) -> str:
class SynchronousOrderedListRuntimeState(userstate.OrderedListRuntimeState):
RANGE_MIN = -(1 << 63)
RANGE_MAX = (1 << 63) - 1
TIMESTAMP_RANGE_MIN = timestamp.Timestamp(micros=RANGE_MIN)
TIMESTAMP_RANGE_MAX = timestamp.Timestamp(micros=RANGE_MAX)

def __init__(
self,
Expand All @@ -777,30 +779,26 @@ def __init__(
self._pending_adds = SortedDict()
self._pending_removes = RangeSet()

def add(self, elem: Tuple[Union[int, timestamp.Timestamp], Any]) -> None:
def add(self, elem: Tuple[timestamp.Timestamp, Any]) -> None:
assert len(elem) == 2
key, value = elem
if isinstance(key, timestamp.Timestamp):
key = key.micros
key = key.micros

if key >= self.RANGE_MAX or key < self.RANGE_MIN:
raise ValueError("key value %d is out of range" % key)
self._pending_adds.setdefault(key, []).append(value)

def read(self) -> Iterable[Tuple[timestamp.Timestamp, Any]]:
return self.read_range(self.RANGE_MIN, self.RANGE_MAX)
return self.read_range(self.TIMESTAMP_RANGE_MIN, self.TIMESTAMP_RANGE_MAX)

def read_range(
self,
min_timestamp: Union[int, timestamp.Timestamp],
limit_timestamp: Union[int, timestamp.Timestamp]
min_timestamp: timestamp.Timestamp,
limit_timestamp: timestamp.Timestamp
) -> Iterable[Tuple[timestamp.Timestamp, Any]]:
# convert timestamp to int, as sort keys are stored as int internally.
if isinstance(min_timestamp, timestamp.Timestamp):
min_timestamp = min_timestamp.micros

if isinstance(limit_timestamp, timestamp.Timestamp):
limit_timestamp = limit_timestamp.micros
min_timestamp = min_timestamp.micros
limit_timestamp = limit_timestamp.micros

keys_to_add = self._pending_adds.irange(
min_timestamp, limit_timestamp, inclusive=(True, False))
Expand Down Expand Up @@ -830,10 +828,10 @@ def read_range(
self._state_handler, range_query_state_key, self._elem_coder))

return map(
lambda x: (timestamp.Timestamp(x[0]), x[1]),
lambda x: (timestamp.Timestamp(micros=x[0]), x[1]),
heapq.merge(persistent_items, local_items))

return map(lambda x: (timestamp.Timestamp(x[0]), x[1]), local_items)
return map(lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), local_items)

def clear(self) -> None:
self._cleared = True
Expand All @@ -843,13 +841,10 @@ def clear(self) -> None:

def clear_range(
self,
min_timestamp: Union[int, timestamp.Timestamp],
limit_timestamp: Union[int, timestamp.Timestamp]) -> None:
if isinstance(min_timestamp, timestamp.Timestamp):
min_timestamp = min_timestamp.micros

if isinstance(limit_timestamp, timestamp.Timestamp):
limit_timestamp = limit_timestamp.micros
min_timestamp: timestamp.Timestamp,
limit_timestamp: timestamp.Timestamp) -> None:
min_timestamp = min_timestamp.micros
limit_timestamp = limit_timestamp.micros

# materialize the keys to remove before the actual removal
keys_to_remove = list(
Expand Down
127 changes: 84 additions & 43 deletions sdks/python/apache_beam/runners/worker/bundle_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.transforms import userstate
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils import timestamp
from apache_beam.utils.windowed_value import WindowedValue


Expand Down Expand Up @@ -446,27 +447,36 @@ def setUp(self):
self.state = self._create_state()

def test_read_range(self):
A1, B1, A4 = [(1, "a1"), (1, "b1"), (4, "a4")]
self.assertEqual([], list(self.state.read_range(0, 5)))
T0 = timestamp.Timestamp.of(0)
T1 = timestamp.Timestamp.of(1)
T2 = timestamp.Timestamp.of(2)
T3 = timestamp.Timestamp.of(3)
T4 = timestamp.Timestamp.of(4)
T5 = timestamp.Timestamp.of(5)
T9 = timestamp.Timestamp.of(9)
A1, B1, A4 = [(T1, "a1"), (T1, "b1"), (T4, "a4")]
self.assertEqual([], list(self.state.read_range(T0, T5)))

self.state.add(A1)
self.assertEqual([A1], list(self.state.read_range(0, 5)))
self.assertEqual([A1], list(self.state.read_range(T0, T5)))

self.state.add(B1)
self.assertEqual([A1, B1], list(self.state.read_range(0, 5)))
self.assertEqual([A1, B1], list(self.state.read_range(T0, T5)))

self.state.add(A4)
self.assertEqual([A1, B1, A4], list(self.state.read_range(0, 5)))
self.assertEqual([A1, B1, A4], list(self.state.read_range(T0, T5)))

self.assertEqual([], list(self.state.read_range(0, 1)))
self.assertEqual([], list(self.state.read_range(5, 10)))
self.assertEqual([A1, B1], list(self.state.read_range(1, 2)))
self.assertEqual([], list(self.state.read_range(2, 3)))
self.assertEqual([], list(self.state.read_range(2, 4)))
self.assertEqual([A4], list(self.state.read_range(4, 5)))
self.assertEqual([], list(self.state.read_range(T0, T1)))
self.assertEqual([], list(self.state.read_range(T5, T9)))
self.assertEqual([A1, B1], list(self.state.read_range(T1, T2)))
self.assertEqual([], list(self.state.read_range(T2, T3)))
self.assertEqual([], list(self.state.read_range(T2, T4)))
self.assertEqual([A4], list(self.state.read_range(T4, T5)))

def test_read(self):
A1, B1, A4 = [(1, "a1"), (1, "b1"), (4, "a4")]
T1 = timestamp.Timestamp.of(1)
T4 = timestamp.Timestamp.of(4)
A1, B1, A4 = [(T1, "a1"), (T1, "b1"), (T4, "a4")]
self.assertEqual([], list(self.state.read()))

self.state.add(A1)
Expand All @@ -482,8 +492,14 @@ def test_read(self):
self.assertEqual([A1, A1, B1, A4], list(self.state.read()))

def test_clear_range(self):
A1, B1, A4, A5 = [(1, "a1"), (1, "b1"), (4, "a4"), (5, "a5")]
self.state.clear_range(0, 1)
T0 = timestamp.Timestamp.of(0)
T1 = timestamp.Timestamp.of(1)
T2 = timestamp.Timestamp.of(2)
T3 = timestamp.Timestamp.of(3)
T4 = timestamp.Timestamp.of(4)
T5 = timestamp.Timestamp.of(5)
A1, B1, A4, A5 = [(T1, "a1"), (T1, "b1"), (T4, "a4"), (T5, "a5")]
self.state.clear_range(T0, T1)
self.assertEqual([], list(self.state.read()))

self.state.add(A1)
Expand All @@ -492,30 +508,34 @@ def test_clear_range(self):
self.state.add(A5)
self.assertEqual([A1, B1, A4, A5], list(self.state.read()))

self.state.clear_range(0, 1)
self.state.clear_range(T0, T1)
self.assertEqual([A1, B1, A4, A5], list(self.state.read()))

self.state.clear_range(1, 2)
self.state.clear_range(T1, T2)
self.assertEqual([A4, A5], list(self.state.read()))

# no side effect on clearing the same range twice
self.state.clear_range(1, 2)
self.state.clear_range(T1, T2)
self.assertEqual([A4, A5], list(self.state.read()))

self.state.clear_range(3, 4)
self.state.clear_range(T3, T4)
self.assertEqual([A4, A5], list(self.state.read()))

self.state.clear_range(3, 5)
self.state.clear_range(T3, T5)
self.assertEqual([A5], list(self.state.read()))

def test_add_and_clear_range_after_commit(self):
A1, B1, C1, A4, A5, A6 = [(1, "a1"), (1, "b1"), (1, "c1"),
(4, "a4"), (5, "a5"), (6, "a6")]
T1 = timestamp.Timestamp.of(1)
T4 = timestamp.Timestamp.of(4)
T5 = timestamp.Timestamp.of(5)
T6 = timestamp.Timestamp.of(6)
A1, B1, C1, A4, A5, A6 = [(T1, "a1"), (T1, "b1"), (T1, "c1"),
(T4, "a4"), (T5, "a5"), (T6, "a6")]
self.state.add(A1)
self.state.add(B1)
self.state.add(A4)
self.state.add(A5)
self.state.clear_range(4, 5)
self.state.clear_range(T4, T5)
self.assertEqual([A1, B1, A5], list(self.state.read()))

self.state.commit()
Expand All @@ -527,7 +547,7 @@ def test_add_and_clear_range_after_commit(self):
self.state.add(A6)
self.assertEqual([A1, B1, C1, A5, A6], list(self.state.read()))

self.state.clear_range(5, 6)
self.state.clear_range(T5, T6)
self.assertEqual([A1, B1, C1, A6], list(self.state.read()))

self.state.commit()
Expand All @@ -536,18 +556,22 @@ def test_add_and_clear_range_after_commit(self):
self.assertEqual([A1, B1, C1, A6], list(self.state.read()))

def test_clear(self):
A1, B1, C1, A4, A5, B5 = [(1, "a1"), (1, "b1"), (1, "c1"),
(4, "a4"), (5, "a5"), (5, "b5")]
T1 = timestamp.Timestamp.of(1)
T4 = timestamp.Timestamp.of(4)
T5 = timestamp.Timestamp.of(5)
T9 = timestamp.Timestamp.of(9)
A1, B1, C1, A4, A5, B5 = [(T1, "a1"), (T1, "b1"), (T1, "c1"),
(T4, "a4"), (T5, "a5"), (T5, "b5")]
self.state.add(A1)
self.state.add(B1)
self.state.add(A4)
self.state.add(A5)
self.state.clear_range(4, 5)
self.state.clear_range(T4, T5)
self.assertEqual([A1, B1, A5], list(self.state.read()))
self.state.commit()

self.state.add(C1)
self.state.clear_range(5, 10)
self.state.clear_range(T5, T9)
self.assertEqual([A1, B1, C1], list(self.state.read()))
self.state.clear()
self.assertEqual(len(self.state._pending_adds), 0)
Expand All @@ -563,7 +587,10 @@ def test_clear(self):
self.assertEqual([B5], list(self.state.read()))

def test_multiple_iterators(self):
A1, B1, A3, B3 = [(1, "a1"), (1, "b1"), (3, "a3"), (3, "b3")]
T1 = timestamp.Timestamp.of(1)
T3 = timestamp.Timestamp.of(3)
T9 = timestamp.Timestamp.of(9)
A1, B1, A3, B3 = [(T1, "a1"), (T1, "b1"), (T3, "a3"), (T3, "b3")]
self.state.add(A1)
self.state.add(A3)
self.state.commit()
Expand All @@ -578,7 +605,7 @@ def test_multiple_iterators(self):
self.state.add(B3)
iter_before_clear_range = iter(self.state.read())
self.assertEqual(A1, next(iter_before_clear_range))
self.state.clear_range(3, 10)
self.state.clear_range(T3, T9)
self.assertEqual(B1, next(iter_before_clear_range))
self.assertEqual(A3, next(iter_before_clear_range))
self.assertEqual(B3, next(iter_before_clear_range))
Expand All @@ -601,10 +628,13 @@ def __init__(self):

def add(self, elem):
k, v = elem
k = k.micros
self._data[k - lower].append(v)
self._logs.append("add(%d, %s)" % (k, v))

def clear_range(self, lo, hi):
lo = lo.micros
hi = hi.micros
for i in range(lo, hi):
self._data[i - lower] = []
self._logs.append("clear_range(%d, %d)" % (lo, hi))
Expand All @@ -618,7 +648,7 @@ def read(self):
self._logs.append("read()")
for i in range(len(self._data)):
for v in self._data[i]:
yield (i + lower, v)
yield (timestamp.Timestamp(micros=(i + lower)), v)

random.seed(seed)

Expand All @@ -630,13 +660,15 @@ def read(self):
op = random.randint(1, 100)
if 1 <= op < 70:
num = random.randint(lower, upper)
state.add((num, "a%d" % num))
bench_state.add((num, "a%d" % num))
state.add((timestamp.Timestamp(micros=num), "a%d" % num))
bench_state.add((timestamp.Timestamp(micros=num), "a%d" % num))
elif 70 <= op < 95:
num1 = random.randint(lower, upper)
num2 = random.randint(lower, upper)
state.clear_range(min(num1, num2), max(num1, num2))
bench_state.clear_range(min(num1, num2), max(num1, num2))
min_time = timestamp.Timestamp(micros=min(num1, num2))
max_time = timestamp.Timestamp(micros=max(num1, num2))
state.clear_range(min_time, max_time)
bench_state.clear_range(min_time, max_time)
elif op >= 95:
state.clear()
bench_state.clear()
Expand Down Expand Up @@ -664,33 +696,42 @@ def test_fuzz(self):
raise RuntimeError("Exception occurred on seed=%d: %s" % (seed, e))

def test_min_max(self):
INT64_MIN, INT64_MAX_MINUS_ONE, INT64_MAX = [(-(1 << 63), "min"),
((1 << 63) - 2, "max"),
((1 << 63) - 1, "err")]
T_MIN = timestamp.Timestamp(micros=(-(1 << 63)))
T_MAX_MINUS_ONE = timestamp.Timestamp(micros=((1 << 63) - 2))
T_MAX = timestamp.Timestamp(micros=((1 << 63) - 1))
T0 = timestamp.Timestamp(micros=0)
INT64_MIN, INT64_MAX_MINUS_ONE, INT64_MAX = [(T_MIN, "min"),
(T_MAX_MINUS_ONE, "max"),
(T_MAX, "err")]
self.state.add(INT64_MIN)
self.state.add(INT64_MAX_MINUS_ONE)
self.assertRaises(ValueError, lambda: self.state.add(INT64_MAX))

self.assertEqual([INT64_MIN, INT64_MAX_MINUS_ONE], list(self.state.read()))
self.assertEqual([INT64_MIN], list(self.state.read_range(-(1 << 63), 0)))
self.assertEqual([INT64_MIN], list(self.state.read_range(T_MIN, T0)))
self.assertEqual([INT64_MAX_MINUS_ONE],
list(self.state.read_range(0, (1 << 63) - 1)))
list(self.state.read_range(T0, T_MAX)))

def test_continuation_token(self):
A1, A2, A7, B7, A8 = [(1, "a1"), (2, "a2"), (7, "a7"), (7, "b7"), (8, "a8")]
T1 = timestamp.Timestamp.of(1)
T2 = timestamp.Timestamp.of(2)
T7 = timestamp.Timestamp.of(7)
T8 = timestamp.Timestamp.of(8)
A1, A2, A7, B7, A8 = [(T1, "a1"), (T2, "a2"), (T7, "a7"),
(T7, "b7"), (T8, "a8")]
self.state._state_handler._underlying._use_continuation_tokens = True
self.assertEqual([], list(self.state.read_range(1, 8)))
self.assertEqual([], list(self.state.read_range(T1, T8)))

self.state.add(A1)
self.state.add(A2)
self.state.add(A7)
self.state.add(B7)
self.state.add(A8)

self.assertEqual([A2, A7, B7], list(self.state.read_range(2, 8)))
self.assertEqual([A2, A7, B7], list(self.state.read_range(T2, T8)))

self.state.commit()
self.assertEqual([A2, A7, B7], list(self.state.read_range(2, 8)))
self.assertEqual([A2, A7, B7], list(self.state.read_range(T2, T8)))

self.assertEqual([A1, A2, A7, B7, A8], list(self.state.read()))

Expand Down
11 changes: 5 additions & 6 deletions sdks/python/apache_beam/transforms/userstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,20 +389,19 @@ class OrderedListRuntimeState(AccumulatingRuntimeState):
def read(self) -> Iterable[Tuple[Timestamp, Any]]:
raise NotImplementedError(type(self))

def add(self, value: Tuple[Union[int, Timestamp], Any]) -> None:
def add(self, value: Tuple[Timestamp, Any]) -> None:
raise NotImplementedError(type(self))

def read_range(
self,
min_time_stamp: Union[int, Timestamp],
limit_time_stamp: Union[int,
Timestamp]) -> Iterable[Tuple[Timestamp, Any]]:
min_time_stamp: Timestamp,
limit_time_stamp: Timestamp) -> Iterable[Tuple[Timestamp, Any]]:
raise NotImplementedError(type(self))

def clear_range(
self,
min_time_stamp: Union[int, Timestamp],
limit_time_stamp: Union[int, Timestamp]) -> None:
min_time_stamp: Timestamp,
limit_time_stamp: Timestamp) -> None:
raise NotImplementedError(type(self))


Expand Down

0 comments on commit 0dbdd41

Please sign in to comment.