Improve docstring and add new tests
jmao-denver committed Oct 18, 2024
1 parent 50dabff commit 64f2351
Showing 2 changed files with 59 additions and 38 deletions.
19 changes: 10 additions & 9 deletions py/server/deephaven/experimental/partitioned_table_service.py
@@ -74,7 +74,7 @@ def existing_partitions(self, table_key: TableKey,
The table should have a single row for the particular partition location key provided in the 1st argument,
with the values for the partition columns in the row.
TODO JF: This is invoked for tables created when make_table's `live` is False.
This is called for tables created when :meth:`PythonTableDataService.make_table` is called with live=False.
Args:
table_key (TableKey): the table key
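A rough sketch of how a backend might implement existing_partitions for a static table (the TestBackend in the tests below follows the same pattern; self._partitions and the partition-column names "Ticker"/"Exchange" are placeholder assumptions, not part of the library API):

import pyarrow as pa
from typing import Callable, Optional

from deephaven.experimental.partitioned_table_service import TableKey, PartitionedTableLocationKey


def existing_partitions(self, table_key: TableKey,
                        callback: Callable[[PartitionedTableLocationKey, Optional[pa.Table]], None]) -> None:
    # Invoke the callback once per known partition; the pa.Table argument carries a
    # single row with that partition's values for the partition columns only.
    for location_key, pa_table in self._partitions.items():  # placeholder backend state
        callback(location_key, pa_table.select(["Ticker", "Exchange"]).slice(0, 1))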
@@ -92,10 +92,10 @@ def subscribe_to_new_partitions(self, table_key: TableKey,
have a single row for the particular partition location key provided in the 1st argument, with the values for
the partition columns in the row.
TODO JF: This is invoked for tables created when make_table's `live` is True.
TODO: add comment if test_make_live_table_observe_subscription_cancellations demonstrates that the subscription
needs to callback for any existing partitions, too (or if existing_partitions will also be invoked when
live == True)
This is called for tables created when :meth:`PythonTableDataService.make_table` is called with live=True.
Any existing partitions created before this method is called should be passed to the callback.
Note that the callback must not be called before this method has returned.
The return value is a function that can be called to unsubscribe from the new partitions.
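A rough sketch of one way a backend might honor this subscription contract, assuming placeholder state on self (a _partitions dict and a _subscribed flag, neither defined by the library): delivery happens on a background thread so no callback runs on the subscriber's call stack, existing partitions are reported first, and the returned function cancels the subscription.

import threading
import time
import pyarrow as pa
from typing import Callable, Optional

from deephaven.experimental.partitioned_table_service import TableKey, PartitionedTableLocationKey


def subscribe_to_new_partitions(self, table_key: TableKey,
                                callback: Callable[[PartitionedTableLocationKey, Optional[pa.Table]], None]
                                ) -> Callable[[], None]:
    self._subscribed = True  # placeholder backend state

    def _deliver():
        # First report the partitions that already exist, then poll for new ones.
        for location_key, pa_table in dict(self._partitions).items():  # placeholder backend state
            callback(location_key, pa_table.select(["Ticker", "Exchange"]).slice(0, 1))
        while self._subscribed:
            time.sleep(0.1)  # poll the data source and invoke callback for each new partition (omitted)

    threading.Thread(target=_deliver, daemon=True).start()

    def _cancel():
        self._subscribed = False

    return _cancel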
@@ -111,7 +111,7 @@ def partition_size(self, table_key: TableKey, table_location_key: PartitionedTab
""" Provides a callback for the backend service to pass the size of the partition with the given table key
and partition location key. The callback should be called with the size of the partition in number of rows.
TODO JF: This is invoked for tables created when make_table's `live` is False.
This is called for tables created when :meth:`PythonTableDataService.make_table` is called with live=False.
Args:
table_key (TableKey): the table key
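A minimal sketch of a matching backend method, again assuming a placeholder self._partitions dict of pyarrow Tables keyed by location:

from typing import Callable

from deephaven.experimental.partitioned_table_service import TableKey, PartitionedTableLocationKey


def partition_size(self, table_key: TableKey, table_location_key: PartitionedTableLocationKey,
                   callback: Callable[[int], None]) -> None:
    # Report the current row count of the requested partition.
    callback(self._partitions[table_location_key].num_rows)  # placeholder backend state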
@@ -127,9 +127,10 @@ def subscribe_to_partition_size_changes(self, table_key: TableKey, table_locatio
table key and partition location key. The callback should be called with the size of the partition in number of
rows.
TODO JF: This is invoked for tables created when make_table's `live` is True.
This callback cannot be invoked until after this method has returned.
This callback must be invoked with the initial size of the partition.
This is called for tables created when :meth:`PythonTableDataService.make_table` is called with live=True.
Note that the callback must be invoked with the initial size of the partition, and neither that initial
invocation nor any subsequent one may happen before this method has returned.
The return value is a function that can be called to unsubscribe from the partition size changes.
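A rough sketch of one way to meet this contract, with placeholder state on self (_partitions and a _size_subscribed flag, both assumptions of this sketch): the background thread reports the initial size first, then any growth, and the returned function cancels the subscription.

import threading
import time
from typing import Callable

from deephaven.experimental.partitioned_table_service import TableKey, PartitionedTableLocationKey


def subscribe_to_partition_size_changes(self, table_key: TableKey,
                                        table_location_key: PartitionedTableLocationKey,
                                        callback: Callable[[int], None]) -> Callable[[], None]:
    self._size_subscribed = True  # placeholder backend state

    def _watch():
        last_size = -1
        while self._size_subscribed:
            size = self._partitions[table_location_key].num_rows  # placeholder backend state
            if size != last_size:  # the first report is the partition's initial size
                callback(size)
                last_size = size
            time.sleep(0.1)

    threading.Thread(target=_watch, daemon=True).start()

    def _cancel():
        self._size_subscribed = False

    return _cancel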
78 changes: 49 additions & 29 deletions py/server/tests/test_partitioned_table_service.py
@@ -5,19 +5,20 @@
import threading
import time
import unittest
from typing import Callable, Tuple, Optional, Generator
from typing import Callable, Tuple, Optional, Generator, List

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

from deephaven import new_table
from deephaven import new_table, garbage_collect
from deephaven.column import byte_col, char_col, short_col, int_col, long_col, float_col, double_col, string_col, \
datetime_col, bool_col, ColumnType
from deephaven.execution_context import get_exec_ctx, ExecutionContext
from deephaven.experimental.partitioned_table_service import PartitionedTableServiceBackend, TableKey, \
PartitionedTableLocationKey, PythonTableDataService
import deephaven.arrow as dharrow
from deephaven.liveness_scope import liveness_scope

from tests.testbase import BaseTestCase

@@ -27,9 +28,11 @@ def __init__(self, gen_pa_table: Generator[pa.Table, None, None], pt_schema: pa.
self.pt_schema = pt_schema
self.pc_schema = pc_schema
self.gen_pa_table = gen_pa_table
self._sub_new_partition_cancelled = False
self._partitions: dict[PartitionedTableLocationKey, pa.Table] = {}
self._partitions_size_subscriptions: dict[PartitionedTableLocationKey, bool] = {}
self.sub_new_partition_cancelled = False
self.partitions: dict[PartitionedTableLocationKey, pa.Table] = {}
self.partitions_size_subscriptions: dict[PartitionedTableLocationKey, bool] = {}
self.existing_partitions_called = 0
self.partition_size_called = 0

def table_schema(self, table_key: TableKey) -> Tuple[pa.Schema, Optional[pa.Schema]]:
if table_key.key == "test":
@@ -42,27 +45,29 @@ def existing_partitions(self, table_key: TableKey, callback: Callable[[Partition
ticker = str(pa_table.column("Ticker")[0])

partition_key = PartitionedTableLocationKey(f"{ticker}/NYSE")
self._partitions[partition_key] = pa_table
self.partitions[partition_key] = pa_table

expr = ((pc.field("Ticker") == f"{ticker}") & (pc.field("Exchange") == "NYSE"))
callback(partition_key, pa_table.filter(expr).select(["Ticker", "Exchange"]).slice(0, 1))
self.existing_partitions_called += 1

def partition_size(self, table_key: TableKey, table_location_key: PartitionedTableLocationKey,
callback: Callable[[int], None]) -> None:
callback(self._partitions[table_location_key].num_rows)
callback(self.partitions[table_location_key].num_rows)
self.partition_size_called += 1

def column_values(self, table_key: TableKey, table_location_key: PartitionedTableLocationKey,
col: str, offset: int, min_rows: int, max_rows: int) -> pa.Table:
if table_key.key == "test":
return self._partitions[table_location_key].select([col]).slice(offset, max_rows)
return self.partitions[table_location_key].select([col]).slice(offset, max_rows)
else:
return pa.table([])

def _th_new_partitions(self, table_key: TableKey, exec_ctx: ExecutionContext, callback: Callable[[PartitionedTableLocationKey, Optional[pa.Table]], None]) -> None:
if table_key.key != "test":
return

while not self._sub_new_partition_cancelled:
while not self.sub_new_partition_cancelled:
try:
with exec_ctx:
pa_table = next(self.gen_pa_table)
@@ -71,7 +76,7 @@ def _th_new_partitions(self, table_key: TableKey, exec_ctx: ExecutionContext, ca

ticker = str(pa_table.column("Ticker")[0])
partition_key = PartitionedTableLocationKey(f"{ticker}/NYSE")
self._partitions[partition_key] = pa_table
self.partitions[partition_key] = pa_table

expr = ((pc.field("Ticker") == f"{ticker}") & (pc.field("Exchange") == "NYSE"))
callback(partition_key, pa_table.filter(expr).select(["Ticker", "Exchange"]).slice(0, 1))
@@ -81,15 +86,12 @@ def subscribe_to_new_partitions(self, table_key: TableKey, callback) -> Callable
if table_key.key != "test":
return lambda: None

# TODO for test count the number opened subscriptions

exec_ctx = get_exec_ctx()
th = threading.Thread(target=self._th_new_partitions, args=(table_key, exec_ctx, callback))
th.start()

def _cancellation_callback():
# TODO for test count the number cancellations
self._sub_new_partition_cancelled = True
self.sub_new_partition_cancelled += 1

return _cancellation_callback

@@ -98,15 +100,15 @@ def _th_partition_size_changes(self, table_key: TableKey, table_location_key: Pa
if table_key.key != "test":
return

if table_location_key not in self._partitions_size_subscriptions:
if table_location_key not in self.partitions_size_subscriptions:
return

while self._partitions_size_subscriptions[table_location_key]:
pa_table = self._partitions[table_location_key]
while self.partitions_size_subscriptions[table_location_key]:
pa_table = self.partitions[table_location_key]
rbs = pa_table.to_batches()
rbs.append(pa_table.to_batches()[0])
new_pa_table = pa.Table.from_batches(rbs)
self._partitions[table_location_key] = new_pa_table
self.partitions[table_location_key] = new_pa_table
callback(new_pa_table.num_rows)
time.sleep(0.1)

@@ -117,16 +119,15 @@ def subscribe_to_partition_size_changes(self, table_key: TableKey,
if table_key.key != "test":
return lambda: None

if table_location_key not in self._partitions:
if table_location_key not in self.partitions:
return lambda: None

self._partitions_size_subscriptions[table_location_key] = True
self.partitions_size_subscriptions[table_location_key] = True
th = threading.Thread(target=self._th_partition_size_changes, args=(table_key, table_location_key, callback))
th.start()

# TODO count number of total subscriptions and number of total cancellations
def _cancellation_callback():
self._partitions_size_subscriptions[table_location_key] = False
self.partitions_size_subscriptions[table_location_key] = False

return _cancellation_callback

@@ -177,6 +178,8 @@ def test_make_static_table_with_partition_schema(self):
self.assertTrue(table.columns[1].column_type == ColumnType.PARTITIONING)
self.assertEqual(table.columns[2:], self.test_table.columns[2:])
self.assertEqual(table.size, 2)
self.assertEqual(backend.existing_partitions_called, 1)
self.assertEqual(backend.partition_size_called, 1)
# how is the table different from the PartitionedTable?

def test_make_live_table_with_partition_schema(self):
@@ -191,6 +194,8 @@ def test_make_live_table_with_partition_schema(self):
self.assertEqual(table.columns[2:], self.test_table.columns[2:])
self.wait_ticking_table_update(table, 20, 5)
self.assertGreaterEqual(table.size, 20)
self.assertEqual(backend.existing_partitions_called, 0)
self.assertEqual(backend.partition_size_called, 0)

def test_make_live_table_with_partition_schema_ops(self):
pc_schema = pa.schema(
@@ -211,18 +216,33 @@ def test_make_live_table_with_partition_schema_ops(self):
self.assertEqual(t.columns, self.test_table.columns)

def test_make_live_table_observe_subscription_cancellations(self):
# coalesce the PartitionAwareSourceTable under a liveness scope
# count number of new partition subscriptions
# count number of partition size subscriptions
# close liveness scope
# assert subscriptions are all closed
pass
pc_schema = pa.schema(
[pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())])
backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema)
data_service = PythonTableDataService(backend)
with liveness_scope():
table = data_service.make_table(TableKey("test"), live=True)
self.wait_ticking_table_update(table, 100, 5)
# table = None
#
# garbage_collect()
# time.sleep(10)
# print(backend.partitions_size_subscriptions.values())
self.assertEqual(backend.sub_new_partition_cancelled, 1)
self.assertFalse(all(backend.partitions_size_subscriptions.values()))

def test_make_live_table_ensure_initial_partitions_exist(self):
# disable new partition subscriptions
# coalesce the PartitionAwareSourceTable
# ensure that all existing partitions were added to the table
pass
pc_schema = pa.schema(
[pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())])
backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema)
backend.sub_new_partition_cancelled = True
data_service = PythonTableDataService(backend)
table = data_service.make_table(TableKey("test"), live=True)
table.coalesce()
self.assertEqual(backend.existing_partitions_called, 0)

if __name__ == '__main__':
unittest.main()
