Skip to content

Commit

Permalink
Support ordered list states in python sdk and fnapi runner (#32326)
Browse files Browse the repository at this point in the history
* Support ordered list state in python sdk and fnapi runner.

* Add test to verify integrity of multiple iterators

* Add fuzz tests and fix two edge cases.

* Add sortedcontainers to package dependencies

* Code refactoring and add a check for the supported maximum key

* Regenerate requirements for Python images.

* Refactor portable runner code for ordered list state

* Return continuation tokens in portable runner for ordered list state

* Fix some lints

* Apply yapf

* Fix lints

* Sync base image requirements with master.

* Add typing for ordered list state apis.

* Add typing to orderedliststate user state.

* Fix a typo.

* Refactor some code based on the feedback.

* Fix lints

* Remove the support of int argument type in ordered list state apis.

* Fix formats and lints

* More lints

* Refactor the code to use the continuation token logic.

* Fix lints
  • Loading branch information
shunping authored Oct 9, 2024
1 parent c31d81c commit 7177baf
Show file tree
Hide file tree
Showing 5 changed files with 604 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
from typing import overload

import grpc
from sortedcontainers import SortedSet

from apache_beam import coders
from apache_beam.io import filesystems
from apache_beam.io.filesystems import CompressionTypes
from apache_beam.portability import common_urns
Expand Down Expand Up @@ -959,7 +961,8 @@ class StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer,
'multimap_keys_values_side_input',
'iterable_side_input',
'bag_user_state',
'multimap_user_state'
'multimap_user_state',
'ordered_list_user_state'
])

class CopyOnWriteState(object):
Expand Down Expand Up @@ -1021,6 +1024,8 @@ def __init__(self):
self._checkpoint = None # type: Optional[StateServicer.StateType]
self._use_continuation_tokens = False
self._continuations = {} # type: Dict[bytes, Tuple[bytes, ...]]
self._ordered_list_keys = collections.defaultdict(
SortedSet) # type: DefaultDict[bytes, SortedSet]

def checkpoint(self):
# type: () -> None
Expand Down Expand Up @@ -1050,6 +1055,14 @@ def process_instruction_id(self, unused_instruction_id):
# type: (Any) -> Iterator
yield

def _get_one_interval_key(self, state_key, start):
  # type: (beam_fn_api_pb2.StateKey, int) -> bytes

  """Builds the storage key for the single sort key ``start``.

  The given ordered-list state key is copied (the argument is not modified),
  its range is set to the half-open interval [start, start + 1), and the
  copy is converted with ``self._to_key``.
  """
  single_key = beam_fn_api_pb2.StateKey()
  single_key.CopyFrom(state_key)
  # Narrow the range on the copy so it covers exactly one sort key.
  sort_key_range = single_key.ordered_list_user_state.range
  sort_key_range.start = start
  sort_key_range.end = start + 1
  return self._to_key(single_key)

def get_raw(self,
state_key, # type: beam_fn_api_pb2.StateKey
continuation_token=None # type: Optional[bytes]
Expand All @@ -1061,7 +1074,30 @@ def get_raw(self,
'Unknown state type: ' + state_key.WhichOneof('type'))

with self._lock:
full_state = self._state[self._to_key(state_key)]
if not continuation_token:
# Compute full_state only when no continuation token is provided.
# If there is continuation token, full_state is already in
# continuation cache. No need to recompute.
full_state = [] # type: List[bytes]
if state_key.WhichOneof('type') == 'ordered_list_user_state':
maybe_start = state_key.ordered_list_user_state.range.start
maybe_end = state_key.ordered_list_user_state.range.end
persistent_state_key = beam_fn_api_pb2.StateKey()
persistent_state_key.CopyFrom(state_key)
persistent_state_key.ordered_list_user_state.ClearField("range")

available_keys = self._ordered_list_keys[self._to_key(
persistent_state_key)]

for i in available_keys.irange(maybe_start,
maybe_end,
inclusive=(True, False)):
entries = self._state[self._get_one_interval_key(
persistent_state_key, i)]
full_state.extend(entries)
else:
full_state.extend(self._state[self._to_key(state_key)])

if self._use_continuation_tokens:
# The token is "nonce:index".
if not continuation_token:
Expand All @@ -1087,14 +1123,40 @@ def append_raw(
):
# type: (...) -> _Future
with self._lock:
self._state[self._to_key(state_key)].append(data)
if state_key.WhichOneof('type') == 'ordered_list_user_state':
coder = coders.TupleCoder([
coders.VarIntCoder(),
coders.coders.LengthPrefixCoder(coders.BytesCoder())
]).get_impl()

for key, value in coder.decode_all(data):
self._state[self._get_one_interval_key(state_key, key)].append(
coder.encode((key, value)))
self._ordered_list_keys[self._to_key(state_key)].add(key)
else:
self._state[self._to_key(state_key)].append(data)
return _Future.done()

def clear(self, state_key):
# type: (beam_fn_api_pb2.StateKey) -> _Future
with self._lock:
try:
del self._state[self._to_key(state_key)]
if state_key.WhichOneof('type') == 'ordered_list_user_state':
start = state_key.ordered_list_user_state.range.start
end = state_key.ordered_list_user_state.range.end
persistent_state_key = beam_fn_api_pb2.StateKey()
persistent_state_key.CopyFrom(state_key)
persistent_state_key.ordered_list_user_state.ClearField("range")
available_keys = self._ordered_list_keys[self._to_key(
persistent_state_key)]

for i in list(available_keys.irange(start,
end,
inclusive=(True, False))):
del self._state[self._get_one_interval_key(persistent_state_key, i)]
available_keys.remove(i)
else:
del self._state[self._to_key(state_key)]
except KeyError:
# This may happen with the caching layer across bundles. Caching may
# skip this storage layer for a blocking_get(key) request. Without
Expand Down
195 changes: 194 additions & 1 deletion sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@

# pytype: skip-file

from __future__ import annotations

import base64
import bisect
import collections
import copy
import heapq
import itertools
import json
import logging
import random
import threading
from dataclasses import dataclass
from dataclasses import field
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
Expand All @@ -50,6 +55,8 @@

from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from sortedcontainers import SortedDict
from sortedcontainers import SortedList

import apache_beam as beam
from apache_beam import coders
Expand Down Expand Up @@ -104,7 +111,8 @@
FnApiUserRuntimeStateTypes = Union['ReadModifyWriteRuntimeState',
'CombiningValueRuntimeState',
'SynchronousSetRuntimeState',
'SynchronousBagRuntimeState']
'SynchronousBagRuntimeState',
'SynchronousOrderedListRuntimeState']

DATA_INPUT_URN = 'beam:runner:source:v1'
DATA_OUTPUT_URN = 'beam:runner:sink:v1'
Expand Down Expand Up @@ -704,6 +712,180 @@ def commit(self):
to_await.get()


class RangeSet:
  """For Internal Use only. A simple set of half-open integer ranges [x, y).

  Ranges that overlap or touch an inserted range are coalesced with it, so
  the stored ranges stay disjoint and sorted.
  """
  def __init__(self) -> None:
    # Start points and end points live in two parallel sorted sequences:
    # the i-th start pairs with the i-th end.
    self._sorted_starts = SortedList()
    self._sorted_ends = SortedList()

  def add(self, start: int, end: int) -> None:
    """Inserts [start, end), merging it with any ranges it overlaps/touches.

    An empty or inverted range (start >= end) is ignored.
    """
    if start >= end:
      return

    starts = self._sorted_starts
    ends = self._sorted_ends

    # Ranges before index ``lo`` and from index ``hi`` on are unaffected.
    # lo: first range whose end point >= the start of the new range.
    lo = ends.bisect_left(start)
    # hi: first range whose start point > the end point of the new range.
    hi = starts.bisect_right(end)

    merged_start = start
    merged_end = end
    if lo < len(starts) and hi > 0:
      # The new range may overlap ranges[lo:hi]; widen it to swallow them.
      merged_start = min(start, starts[lo])
      merged_end = max(end, ends[hi - 1])

    # When nothing is absorbed the slice lo:hi is empty and this is a no-op.
    del starts[lo:hi]
    del ends[lo:hi]

    starts.add(merged_start)
    ends.add(merged_end)

  def __contains__(self, key: int) -> bool:
    """Returns True if ``key`` falls inside one of the stored ranges."""
    starts = self._sorted_starts
    idx = starts.bisect_left(key)
    if idx < len(starts) and starts[idx] == key:
      # key is exactly the (inclusive) start of a range.
      return True
    # Otherwise only the preceding range can cover key; its end is exclusive.
    return idx > 0 and self._sorted_ends[idx - 1] > key

  def __len__(self) -> int:
    """Returns the number of disjoint ranges currently stored."""
    range_count = len(self._sorted_starts)
    assert range_count == len(self._sorted_ends)
    return range_count

  def __iter__(self) -> Iterator[Tuple[int, int]]:
    """Iterates over the stored (start, end) pairs in ascending order."""
    return zip(self._sorted_starts, self._sorted_ends)

  def __str__(self) -> str:
    return str(list(iter(self)))


class SynchronousOrderedListRuntimeState(userstate.OrderedListRuntimeState):
  """Ordered list user state that buffers mutations locally until commit().

  Elements are (timestamp, value) pairs whose sort key is the timestamp's
  ``micros`` value. Additions are held in ``_pending_adds`` (sort key -> list
  of values) and removed key ranges in ``_pending_removes``; both are sent
  through the state handler by ``commit()``.
  """
  RANGE_MIN = -(1 << 63)
  RANGE_MAX = (1 << 63) - 1
  TIMESTAMP_RANGE_MIN = timestamp.Timestamp(micros=RANGE_MIN)
  TIMESTAMP_RANGE_MAX = timestamp.Timestamp(micros=RANGE_MAX)

  def __init__(
      self,
      state_handler: sdk_worker.CachingStateHandler,
      state_key: beam_fn_api_pb2.StateKey,
      value_coder: coders.Coder) -> None:
    """Initializes the state with empty local buffers.

    Args:
      state_handler: handler used to read, clear and extend the state.
      state_key: state key identifying this ordered list (without a range).
      value_coder: coder for the element values; it is length-prefixed and
        paired with a VarIntCoder for the sort key.
    """
    self._state_handler = state_handler
    self._state_key = state_key
    self._elem_coder = beam.coders.TupleCoder(
        [coders.VarIntCoder(), coders.coders.LengthPrefixCoder(value_coder)])
    # True once clear() has been called and until the next commit(); while
    # set, reads ignore whatever the state handler holds.
    self._cleared = False
    self._pending_adds = SortedDict()
    self._pending_removes = RangeSet()

  def add(self, elem: Tuple[timestamp.Timestamp, Any]) -> None:
    """Buffers one (timestamp, value) element for the next commit().

    Raises:
      ValueError: if the timestamp's micros is outside
        [RANGE_MIN, RANGE_MAX).
    """
    assert len(elem) == 2
    key_ts, value = elem
    key = key_ts.micros

    if key >= self.RANGE_MAX or key < self.RANGE_MIN:
      raise ValueError("key value %d is out of range" % key)
    self._pending_adds.setdefault(key, []).append(value)

  def read(self) -> Iterable[Tuple[timestamp.Timestamp, Any]]:
    """Reads every element, i.e. the whole supported timestamp range."""
    return self.read_range(self.TIMESTAMP_RANGE_MIN, self.TIMESTAMP_RANGE_MAX)

  def read_range(
      self,
      min_timestamp: timestamp.Timestamp,
      limit_timestamp: timestamp.Timestamp
  ) -> Iterable[Tuple[timestamp.Timestamp, Any]]:
    """Reads the elements with min_timestamp <= timestamp < limit_timestamp.

    Elements obtained through the state handler (minus the ranges pending
    removal) are merged with locally buffered additions in sort-key order.
    For equal sort keys, elements from the state handler are yielded before
    locally buffered ones; the values themselves are never compared.

    Returns:
      A lazy iterable of (timestamp, value) pairs.
    """
    # convert timestamp to int, as sort keys are stored as int internally.
    min_key = min_timestamp.micros
    limit_key = limit_timestamp.micros

    keys_to_add = self._pending_adds.irange(
        min_key, limit_key, inclusive=(True, False))

    # Use a list comprehension here so the per-key iterators of the selected
    # range are constructed right away. islice bounds each of them by the
    # length the value list has at this point.
    local_items = chain.from_iterable([
        itertools.islice(
            zip(itertools.repeat(k), self._pending_adds[k]),
            len(self._pending_adds[k])) for k in keys_to_add
    ])

    if self._cleared:
      # Everything behind the state handler is logically gone; only local
      # additions are visible.
      return map(
          lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), local_items)

    range_query_state_key = beam_fn_api_pb2.StateKey()
    range_query_state_key.CopyFrom(self._state_key)
    range_query_state_key.ordered_list_user_state.range.start = min_key
    range_query_state_key.ordered_list_user_state.range.end = limit_key

    # make a deep copy here because there could be other operations occur in
    # the middle of an iteration and change pending_removes
    pending_removes_snapshot = copy.deepcopy(self._pending_removes)
    persistent_items = filter(
        lambda kv: kv[0] not in pending_removes_snapshot,
        _StateBackedIterable(
            self._state_handler, range_query_state_key, self._elem_coder))

    # Merge on the sort key only. Comparing whole (key, value) tuples would
    # compare the values whenever a persisted and a pending element share a
    # sort key, which raises TypeError for values that are not orderable.
    return map(
        lambda x: (timestamp.Timestamp(micros=x[0]), x[1]),
        heapq.merge(persistent_items, local_items, key=lambda kv: kv[0]))

  def clear(self) -> None:
    """Drops all buffered additions and marks the whole key range removed."""
    self._cleared = True
    self._pending_adds = SortedDict()
    self._pending_removes = RangeSet()
    self._pending_removes.add(self.RANGE_MIN, self.RANGE_MAX)

  def clear_range(
      self,
      min_timestamp: timestamp.Timestamp,
      limit_timestamp: timestamp.Timestamp) -> None:
    """Removes the elements with min_timestamp <= timestamp < limit_timestamp.

    Buffered additions in the range are dropped immediately; the range is
    recorded for removal at commit() unless a full clear is already pending.
    """
    min_key = min_timestamp.micros
    limit_key = limit_timestamp.micros

    # materialize the keys to remove before the actual removal
    keys_to_remove = list(
        self._pending_adds.irange(min_key, limit_key, inclusive=(True, False)))
    for k in keys_to_remove:
      del self._pending_adds[k]

    if not self._cleared:
      self._pending_removes.add(min_key, limit_key)

  def commit(self) -> None:
    """Sends pending removals, then pending additions, and waits for them.

    The local buffers and the cleared flag are reset afterwards.
    """
    futures = []
    if self._pending_removes:
      for start, end in self._pending_removes:
        range_query_state_key = beam_fn_api_pb2.StateKey()
        range_query_state_key.CopyFrom(self._state_key)
        range_query_state_key.ordered_list_user_state.range.start = start
        range_query_state_key.ordered_list_user_state.range.end = end
        futures.append(self._state_handler.clear(range_query_state_key))

      self._pending_removes = RangeSet()

    if self._pending_adds:
      items_to_add = []
      for k, values in self._pending_adds.items():
        items_to_add.extend((k, v) for v in values)
      futures.append(
          self._state_handler.extend(
              self._state_key, self._elem_coder.get_impl(), items_to_add))
      self._pending_adds = SortedDict()

    # To commit, we need to wait on every state request futures to complete.
    for to_await in futures:
      to_await.get()

    self._cleared = False


class OutputTimer(userstate.BaseTimer):
def __init__(self,
key,
Expand Down Expand Up @@ -850,6 +1032,17 @@ def _create_state(self,
# State keys are expected in nested encoding format
key=self._key_coder.encode_nested(key))),
value_coder=state_spec.coder)
elif isinstance(state_spec, userstate.OrderedListStateSpec):
return SynchronousOrderedListRuntimeState(
self._state_handler,
state_key=beam_fn_api_pb2.StateKey(
ordered_list_user_state=beam_fn_api_pb2.StateKey.
OrderedListUserState(
transform_id=self._transform_id,
user_state_id=state_spec.name,
window=self._window_coder.encode(window),
key=self._key_coder.encode_nested(key))),
value_coder=state_spec.coder)
else:
raise NotImplementedError(state_spec)

Expand Down
Loading

0 comments on commit 7177baf

Please sign in to comment.