diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index c4edaa85a89d..63bd5651def0 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,3 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run" } + diff --git a/CHANGES.md b/CHANGES.md index 941ba23a7573..a02be953d11e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ ## New Features / Improvements * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Bigtable enrichment handler now accepts a custom function to build a composite row key. (Python) ([#30975](https://github.com/apache/beam/issues/30975)). ## Breaking Changes diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index af35f91a42f3..ddb62c2f60d5 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -16,6 +16,7 @@ # import logging from typing import Any +from typing import Callable from typing import Dict from typing import Optional @@ -33,6 +34,8 @@ 'BigTableEnrichmentHandler', ] +RowKeyFn = Callable[[beam.Row], bytes] + _LOGGER = logging.getLogger(__name__) @@ -53,6 +56,8 @@ class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): See https://cloud.google.com/bigtable/docs/app-profiles for more details. encoding (str): encoding type to convert the string to bytes and vice-versa from BigTable. Default is `utf-8`. + row_key_fn: a function that returns the row key as bytes from the + input row. It is used to build/extract the row key for Bigtable. exception_level: a `enum.Enum` value from ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`` to set the level when an empty row is returned from the BigTable query. 
@@ -67,11 +72,12 @@ def __init__( project_id: str, instance_id: str, table_id: str, - row_key: str, + row_key: str = "", row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1), *, app_profile_id: str = None, # type: ignore[assignment] encoding: str = 'utf-8', + row_key_fn: Optional[RowKeyFn] = None, exception_level: ExceptionLevel = ExceptionLevel.WARN, include_timestamp: bool = False, ): @@ -82,8 +88,15 @@ def __init__( self._row_filter = row_filter self._app_profile_id = app_profile_id self._encoding = encoding + self._row_key_fn = row_key_fn self._exception_level = exception_level self._include_timestamp = include_timestamp + if ((not self._row_key_fn and not self._row_key) or + bool(self._row_key_fn and self._row_key)): + raise ValueError( + "Please specify exactly one of `row_key` or a lambda " + "function with `row_key_fn` to extract the row key " + "from the input row.") def __enter__(self): """connect to the Google BigTable cluster.""" @@ -105,9 +118,12 @@ def __call__(self, request: beam.Row, *args, **kwargs): response_dict: Dict[str, Any] = {} row_key_str: str = "" try: - request_dict = request._asdict() - row_key_str = str(request_dict[self._row_key]) - row_key = row_key_str.encode(self._encoding) + if self._row_key_fn: + row_key = self._row_key_fn(request) + else: + request_dict = request._asdict() + row_key_str = str(request_dict[self._row_key]) + row_key = row_key_str.encode(self._encoding) row = self._table.read_row(row_key, filter_=self._row_filter) if row: for cf_id, cf_v in row.cells.items(): @@ -148,4 +164,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_cache_key(self, request: beam.Row) -> str: """Returns a string formatted with row key since it is unique to a request made to `Bigtable`.""" + if self._row_key_fn: + return "row_key: %s" % str(self._row_key_fn(request)) return "%s: %s" % (self._row_key, request._asdict()[self._row_key]) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py 
b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index bbabbc306e61..79d73178e94e 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -46,6 +46,11 @@ _LOGGER = logging.getLogger(__name__) +def _row_key_fn(request: beam.Row) -> bytes: + row_key = str(request.product_id) # type: ignore[attr-defined] + return row_key.encode(encoding='utf-8') + + class ValidateResponse(beam.DoFn): """ValidateResponse validates if a PCollection of `beam.Row` has the required fields.""" @@ -426,6 +431,29 @@ def test_bigtable_enrichment_with_redis(self): expected_enriched_fields))) BigTableEnrichmentHandler.__call__ = actual + def test_bigtable_enrichment_with_lambda(self): + expected_fields = [ + 'sale_id', 'customer_id', 'product_id', 'quantity', 'product' + ] + expected_enriched_fields = { + 'product': ['product_id', 'product_name', 'product_stock'], + } + bigtable = BigTableEnrichmentHandler( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key_fn=_row_key_fn) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ BigTable" >> Enrichment(bigtable) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_test.py new file mode 100644 index 000000000000..1c5cb4064e0e --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_test.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from parameterized import parameterized + +try: + from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.bigtable_it_test import _row_key_fn +except ImportError: + raise unittest.SkipTest('Bigtable test dependencies are not installed.') + + +class TestBigTableEnrichmentHandler(unittest.TestCase): + @parameterized.expand([('product_id', _row_key_fn), ('', None)]) + def test_bigtable_enrichment_invalid_args(self, row_key, row_key_fn): + with self.assertRaises(ValueError): + _ = BigTableEnrichmentHandler( + project_id='apache-beam-testing', + instance_id='beam-test', + table_id='bigtable-enrichment-test', + row_key=row_key, + row_key_fn=row_key_fn) + + +if __name__ == '__main__': + unittest.main()