allow lambda function in bigtable handler
riteshghorse committed Apr 15, 2024
1 parent cd253fd commit bd28cf7
Showing 2 changed files with 49 additions and 5 deletions.
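
The change adds an optional `row_key_fn` parameter to `BigTableEnrichmentHandler`, so the Bigtable row key can be computed from the whole input row instead of being read from the single field named by `row_key`. A minimal usage sketch follows; the project, instance, and table IDs are placeholders, and the lambda assumes the input rows carry a `product_id` field, as in the integration test added below:

    import apache_beam as beam
    from apache_beam.transforms.enrichment import Enrichment
    from apache_beam.transforms.enrichment_handlers.bigtable import (
        BigTableEnrichmentHandler)

    # Placeholder GCP resources; running this needs a real Bigtable table.
    handler = BigTableEnrichmentHandler(
        project_id='my-project',
        instance_id='my-instance',
        table_id='my-table',
        # Derive the row key from the input row instead of naming a field.
        row_key_fn=lambda row: str(row.product_id))

    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | beam.Create([beam.Row(sale_id=1, customer_id=1, product_id=1)])
            | Enrichment(handler))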
27 changes: 22 additions & 5 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py
@@ -16,6 +16,7 @@
#
import logging
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional

@@ -33,6 +34,8 @@
'BigTableEnrichmentHandler',
]

RowKeyFn = Callable[[beam.Row], str]

_LOGGER = logging.getLogger(__name__)


@@ -45,14 +48,17 @@ class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
instance_id (str): GCP instance-id of the BigTable cluster.
table_id (str): GCP table-id of the BigTable.
row_key (str): unique row-key field name from the input `beam.Row` object
-   to use as `row_key` for BigTable querying.
+   to use as `row_key` for BigTable querying. This parameter will be ignored
+   if a lambda function is specified with `row_key_fn`.
row_filter: a ``:class:`google.cloud.bigtable.row_filters.RowFilter``` to
filter data read with ``read_row()``.
Defaults to `CellsColumnLimitFilter(1)`.
app_profile_id (str): App profile ID to use for BigTable.
See https://cloud.google.com/bigtable/docs/app-profiles for more details.
encoding (str): encoding type to convert the string to bytes and vice-versa
from BigTable. Default is `utf-8`.
row_key_fn: a lambda function that returns a string row key from the
input row. It is used to build/extract the row key for Bigtable.
exception_level: a `enum.Enum` value from
``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel``
to set the level when an empty row is returned from the BigTable query.
@@ -67,11 +73,12 @@ def __init__(
project_id: str,
instance_id: str,
table_id: str,
- row_key: str,
+ row_key: str = "",
row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1),
*,
app_profile_id: str = None, # type: ignore[assignment]
encoding: str = 'utf-8',
row_key_fn: Optional[RowKeyFn] = None,
exception_level: ExceptionLevel = ExceptionLevel.WARN,
include_timestamp: bool = False,
):
@@ -82,8 +89,13 @@ def __init__(
self._row_filter = row_filter
self._app_profile_id = app_profile_id
self._encoding = encoding
self._row_key_fn = row_key_fn
self._exception_level = exception_level
self._include_timestamp = include_timestamp
if not bool(self._row_key_fn or self._row_key):
raise ValueError(
"Please specify either `row_key` or a lambda function "
"with `row_key_fn` to extract row key from input row.")

def __enter__(self):
"""connect to the Google BigTable cluster."""
@@ -105,9 +117,12 @@ def __call__(self, request: beam.Row, *args, **kwargs):
response_dict: Dict[str, Any] = {}
row_key_str: str = ""
try:
-     request_dict = request._asdict()
-     row_key_str = str(request_dict[self._row_key])
-     row_key = row_key_str.encode(self._encoding)
+     if self._row_key_fn:
+       row_key = str(self._row_key_fn(request))
+     else:
+       request_dict = request._asdict()
+       row_key_str = str(request_dict[self._row_key])
+       row_key = row_key_str.encode(self._encoding)
row = self._table.read_row(row_key, filter_=self._row_filter)
if row:
for cf_id, cf_v in row.cells.items():
@@ -148,4 +163,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def get_cache_key(self, request: beam.Row) -> str:
"""Returns a string formatted with row key since it is unique to
a request made to `Bigtable`."""
if self._row_key_fn:
return "row_key: %s" % self._row_key_fn(request)
return "%s: %s" % (self._row_key, request._asdict()[self._row_key])
27 additions & 0 deletions in the second changed file (BigTable enrichment handler integration test)
@@ -46,6 +46,10 @@
_LOGGER = logging.getLogger(__name__)


def _row_key_fn(request: beam.Row) -> str:
return request.product_id


class ValidateResponse(beam.DoFn):
"""ValidateResponse validates if a PCollection of `beam.Row`
has the required fields."""
@@ -426,6 +430,29 @@ def test_bigtable_enrichment_with_redis(self):
expected_enriched_fields)))
BigTableEnrichmentHandler.__call__ = actual

def test_bigtable_enrichment_with_lambda(self):
expected_fields = [
'sale_id', 'customer_id', 'product_id', 'quantity', 'product'
]
expected_enriched_fields = {
'product': ['product_id', 'product_name', 'product_stock'],
}
bigtable = BigTableEnrichmentHandler(
project_id=self.project_id,
instance_id=self.instance_id,
table_id=self.table_id,
row_key_fn=_row_key_fn)
with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (
test_pipeline
| "Create" >> beam.Create(self.req)
| "Enrich W/ BigTable" >> Enrichment(bigtable)
| "Validate Response" >> beam.ParDo(
ValidateResponse(
len(expected_fields),
expected_fields,
expected_enriched_fields)))


if __name__ == '__main__':
unittest.main()
