From 4b351c2e87f163845799029480a77aabc7c9a3be Mon Sep 17 00:00:00 2001 From: riteshghorse Date: Mon, 15 Apr 2024 13:51:05 -0400 Subject: [PATCH] allow lambda function in bigtable handler --- CHANGES.md | 1 + .../enrichment_handlers/bigtable.py | 27 ++++++++++++++---- .../enrichment_handlers/bigtable_it_test.py | 28 +++++++++++++++++++ 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 941ba23a7573c..a02be953d11e5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ ## New Features / Improvements * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Bigtable enrichment handler now accepts a custom function to build a composite row key. (Python) ([#30974](https://github.com/apache/beam/issues/30975)). ## Breaking Changes diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index af35f91a42f34..896a5297b576c 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -16,6 +16,7 @@ # import logging from typing import Any +from typing import Callable from typing import Dict from typing import Optional @@ -33,6 +34,8 @@ 'BigTableEnrichmentHandler', ] +RowKeyFn = Callable[[beam.Row], bytes] + _LOGGER = logging.getLogger(__name__) @@ -45,7 +48,8 @@ class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): instance_id (str): GCP instance-id of the BigTable cluster. table_id (str): GCP table-id of the BigTable. row_key (str): unique row-key field name from the input `beam.Row` object - to use as `row_key` for BigTable querying. + to use as `row_key` for BigTable querying. This parameter will be ignored + if a lambda function is specified with `row_key_fn`. row_filter: a ``:class:`google.cloud.bigtable.row_filters.RowFilter``` to filter data read with ``read_row()``. Defaults to `CellsColumnLimitFilter(1)`. @@ -53,6 +57,8 @@ class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): See https://cloud.google.com/bigtable/docs/app-profiles for more details. encoding (str): encoding type to convert the string to bytes and vice-versa from BigTable. Default is `utf-8`. + row_key_fn: a lambda function that returns a string row key from the + input row. It is used to build/extract the row key for Bigtable. exception_level: a `enum.Enum` value from ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`` to set the level when an empty row is returned from the BigTable query. @@ -67,11 +73,12 @@ def __init__( project_id: str, instance_id: str, table_id: str, - row_key: str, + row_key: str = "", row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1), *, app_profile_id: str = None, # type: ignore[assignment] encoding: str = 'utf-8', + row_key_fn: Optional[RowKeyFn] = None, exception_level: ExceptionLevel = ExceptionLevel.WARN, include_timestamp: bool = False, ): @@ -82,8 +89,13 @@ def __init__( self._row_filter = row_filter self._app_profile_id = app_profile_id self._encoding = encoding + self._row_key_fn = row_key_fn self._exception_level = exception_level self._include_timestamp = include_timestamp + if not bool(self._row_key_fn or self._row_key): + raise ValueError( + "Please specify either `row_key` or a lambda function " + "with `row_key_fn` to extract row key from input row.") def __enter__(self): """connect to the Google BigTable cluster.""" @@ -105,9 +117,12 @@ def __call__(self, request: beam.Row, *args, **kwargs): response_dict: Dict[str, Any] = {} row_key_str: str = "" try: - request_dict = request._asdict() - row_key_str = str(request_dict[self._row_key]) - row_key = row_key_str.encode(self._encoding) + if self._row_key_fn: + row_key = self._row_key_fn(request) + else: + request_dict = request._asdict() + row_key_str = str(request_dict[self._row_key]) + row_key = row_key_str.encode(self._encoding) row = self._table.read_row(row_key, filter_=self._row_filter) if row: for cf_id, cf_v in row.cells.items(): @@ -148,4 +163,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_cache_key(self, request: beam.Row) -> str: """Returns a string formatted with row key since it is unique to a request made to `Bigtable`.""" + if self._row_key_fn: + return "row_key: %s" % self._row_key_fn(request) return "%s: %s" % (self._row_key, request._asdict()[self._row_key]) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index bbabbc306e611..79d73178e94e1 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -46,6 +46,11 @@ _LOGGER = logging.getLogger(__name__) +def _row_key_fn(request: beam.Row) -> bytes: + row_key = str(request.product_id) # type: ignore[attr-defined] + return row_key.encode(encoding='utf-8') + + class ValidateResponse(beam.DoFn): """ValidateResponse validates if a PCollection of `beam.Row` has the required fields.""" @@ -426,6 +431,29 @@ def test_bigtable_enrichment_with_redis(self): expected_enriched_fields))) BigTableEnrichmentHandler.__call__ = actual + def test_bigtable_enrichment_with_lambda(self): + expected_fields = [ + 'sale_id', 'customer_id', 'product_id', 'quantity', 'product' + ] + expected_enriched_fields = { + 'product': ['product_id', 'product_name', 'product_stock'], + } + bigtable = BigTableEnrichmentHandler( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key_fn=_row_key_fn) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ BigTable" >> Enrichment(bigtable) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + if __name__ == '__main__': unittest.main()