From ca86df1db6623df6856404e07f49185740bffc94 Mon Sep 17 00:00:00 2001 From: Anand Inguva <34158215+AnandInguva@users.noreply.github.com> Date: Thu, 6 Jul 2023 13:21:54 -0400 Subject: [PATCH] MLTransform (#26795) * Initial work on MLTransform and ProcessHandler * Support for containers: List, Dict[str, np.ndarray] pass types Support Pyarrow schema Artifact WIP * Add min, max, artifacts for scale_0_to_1 * Add more transform functions and artifacts WIP on inferring types Remove pyarrow implementation Add MLTransformOutput Refactor files * Add generic type annotations * Add unit tests Fix artifacts code Add more tests fix lint erors Change namespaces from ml_transform to transforms Add doc strings Add tests and refactor * Add support for saving intermediate results for a transform Sort imports Add metrics namespaces Refactor * Add schema to the output PCollection * Remove MLTransformOutput and return Row instead with schema * Convert primitive type to list using a DoFn. Remove FixedLenFeatureSpec Make VarLenFeatureSpec as default Refactoring * Add append_transform to the ProcessHandler Some more refactoring * Remove param self.has_artifacts, add artifact_location to handler..and address PR comments Add skip conditions for tests Add test suite for tft tests * Move tensorflow import into the try except catch Try except in __init__.py Remove imports from __init__ Add docstrings, refactor * Add type annotations for the data transforms * Add tft test in tox.ini Mock tensorflow_transform in pydocs fix tft pypi name Skip a test Add step name Update supported versions of TFT * Add step name for TFTProcessHandler * Remove unsupported tft versions * Fix mypy * Refactor TFTProcessHandlerDict to TFTProcessHandlerSchema * Update doc for data processing transforms * Fix checking the typing container types * Refactor code * Fail TFTProcessHandler on a non-global window PColl * Remove underscore * Remove high level functions * Add TFIDF * Fix tests with new changes[WIP] * Fix tests * Refactor class name to CamelCase and remove kwrags * use is_default instead of isinstance * Remove falling back to staging location for artifact location * Add TFIDF tests * Remove __str__ * Refactor skip statement * Add utils for fetching artifacts on compute and apply vocab * Make ProcessHandler internal class * Only run analyze stage when transform_fn(artifacts) is not computed before. * Fail if pipeline has non default window during artifact producing stage * Add support for Dict, recordbatch and introduce artifact_mode * Hide process_handler from user. 
Make TFTProcessHandler as default * Refactor few tests * Comment a test * Save raw_data_meta_data so that it can be used during consume stage * Refactor code * Add test on artifacts * Fix imports * Add tensorflow_metadata to pydocs * Fix test * Add TFIDF to import * Add basic example * Remove redundant logging statements * Add test for multiple columns on MLTransform * Add todo about what to do when new process handler is introduced * Add abstractmethod decorator * Edit Error message * Update docs, error messages * Remove record batch input/output arg * Modify generic types * Fix import sort * Fix mypy errors - best effort * Fix tests * Add TFTOperation doc * Rename tft_transform to tft * Fix hadler_test * Fix base_test * Fix pydocs --- .../ml_transform/ml_transform_basic.py | 118 +++++ .../apache_beam/ml/transforms/__init__.py | 16 + sdks/python/apache_beam/ml/transforms/base.py | 165 +++++++ .../apache_beam/ml/transforms/base_test.py | 246 ++++++++++ .../apache_beam/ml/transforms/handlers.py | 410 ++++++++++++++++ .../ml/transforms/handlers_test.py | 355 ++++++++++++++ sdks/python/apache_beam/ml/transforms/tft.py | 440 ++++++++++++++++++ .../apache_beam/ml/transforms/tft_test.py | 395 ++++++++++++++++ .../python/apache_beam/ml/transforms/utils.py | 56 +++ sdks/python/scripts/generate_pydoc.sh | 2 +- sdks/python/test-suites/tox/py38/build.gradle | 4 + sdks/python/tox.ini | 6 + 12 files changed, 2212 insertions(+), 1 deletion(-) create mode 100644 sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py create mode 100644 sdks/python/apache_beam/ml/transforms/__init__.py create mode 100644 sdks/python/apache_beam/ml/transforms/base.py create mode 100644 sdks/python/apache_beam/ml/transforms/base_test.py create mode 100644 sdks/python/apache_beam/ml/transforms/handlers.py create mode 100644 sdks/python/apache_beam/ml/transforms/handlers_test.py create mode 100644 sdks/python/apache_beam/ml/transforms/tft.py create mode 100644 sdks/python/apache_beam/ml/transforms/tft_test.py create mode 100644 sdks/python/apache_beam/ml/transforms/utils.py diff --git a/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py new file mode 100644 index 000000000000..65e943f5c697 --- /dev/null +++ b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py @@ -0,0 +1,118 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This example demonstrates how to use MLTransform. +MLTransform is a PTransform that applies multiple data transformations on the +incoming data. + +This example computes the vocabulary on the incoming data. Then, it computes +the TF-IDF of the incoming data using the vocabulary computed in the previous +step. + +1. 
ComputeAndApplyVocabulary computes the vocabulary on the incoming data and + overrides the incoming data with the vocabulary indices. +2. TFIDF computes the TF-IDF of the incoming data using the vocabulary and + provides vocab_index and tf-idf weights. vocab_index is suffixed with + '_vocab_index' and tf-idf weights are suffixed with '_tfidf' to the + original column name(which is the output of ComputeAndApplyVocabulary). + +MLTransform produces artifacts, for example: ComputeAndApplyVocabulary produces +a text file that contains vocabulary which is saved in `artifact_location`. +ComputeAndApplyVocabulary outputs vocab indices associated with the saved vocab +file. This mode of MLTransform is called artifact `produce` mode. +This will be useful when the data is preprocessed before ML model training. + +The second mode of MLTransform is artifact `consume` mode. In this mode, the +transformations are applied on the incoming data using the artifacts produced +by the previous run of MLTransform. This mode will be useful when the data is +preprocessed before ML model inference. +""" + +import argparse +import logging +import tempfile + +import apache_beam as beam +from apache_beam.ml.transforms.base import ArtifactMode +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import TFIDF +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.ml.transforms.utils import ArtifactsFetcher + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--artifact_location', type=str, default='') + return parser.parse_known_args() + + +def run(args): + data = [ + dict(x=["Let's", "go", "to", "the", "park"]), + dict(x=["I", "enjoy", "going", "to", "the", "park"]), + dict(x=["I", "enjoy", "reading", "books"]), + dict(x=["Beam", "can", "be", "fun"]), + dict(x=["The", "weather", "is", "really", "nice", "today"]), + dict(x=["I", "love", "to", "go", "to", "the", "park"]), + dict(x=["I", "love", "to", "read", "books"]), + dict(x=["I", "love", "to", "program"]), + ] + + with beam.Pipeline() as p: + input_data = p | beam.Create(data) + + # arfifacts produce mode. + input_data |= ( + 'MLTransform' >> MLTransform( + artifact_location=args.artifact_location, + artifact_mode=ArtifactMode.PRODUCE, + ).with_transform(ComputeAndApplyVocabulary( + columns=['x'])).with_transform(TFIDF(columns=['x']))) + + # _ = input_data | beam.Map(logging.info) + + with beam.Pipeline() as p: + input_data = [ + dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam']) + ] + input_data = p | beam.Create(input_data) + + # artifacts consume mode. + input_data |= ( + MLTransform( + artifact_location=args.artifact_location, + artifact_mode=ArtifactMode.CONSUME, + # you don't need to specify transforms as they are already saved in + # in the artifacts. + )) + + _ = input_data | beam.Map(logging.info) + + # To fetch the artifacts after the pipeline is run + artifacts_fetcher = ArtifactsFetcher(artifact_location=args.artifact_location) + vocab_list = artifacts_fetcher.get_vocab_list() + assert vocab_list[22] == 'Beam' + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + args, pipeline_args = parse_args() + # for this example, create a temp artifact location if not provided. 
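+  # A temporary directory keeps this example self-contained; in a real
+  # pipeline the artifact location would typically be a durable path
+  # (for example a GCS bucket) so that a later consume-mode run can read
+  # the artifacts produced here.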
+ if args.artifact_location == '': + args.artifact_location = tempfile.mkdtemp() + run(args) diff --git a/sdks/python/apache_beam/ml/transforms/__init__.py b/sdks/python/apache_beam/ml/transforms/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py new file mode 100644 index 000000000000..f29064094844 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -0,0 +1,165 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pytype: skip-file + +import abc +from typing import Generic +from typing import List +from typing import Optional +from typing import Sequence +from typing import TypeVar + +import apache_beam as beam + +__all__ = ['MLTransform', 'ProcessHandler', 'BaseOperation'] + +TransformedDatasetT = TypeVar('TransformedDatasetT') +TransformedMetadataT = TypeVar('TransformedMetadataT') + +# Input/Output types to the MLTransform. +ExampleT = TypeVar('ExampleT') +MLTransformOutputT = TypeVar('MLTransformOutputT') + +# Input to the apply() method of BaseOperation. +OperationInputT = TypeVar('OperationInputT') +# Output of the apply() method of BaseOperation. +OperationOutputT = TypeVar('OperationOutputT') + + +class ArtifactMode(object): + PRODUCE = 'produce' + CONSUME = 'consume' + + +class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC): + def __init__(self, columns: List[str]) -> None: + """ + Base Opertation class data processing transformations. + Args: + columns: List of column names to apply the transformation. + """ + self.columns = columns + + @abc.abstractmethod + def apply( + self, data: OperationInputT, output_column_name: str) -> OperationOutputT: + """ + Define any processing logic in the apply() method. + processing logics are applied on inputs and returns a transformed + output. + Args: + inputs: input data. 
+ """ + + +class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC): + """ + Only for internal use. No backwards compatibility guarantees. + """ + @abc.abstractmethod + def process_data( + self, pcoll: beam.PCollection[ExampleT] + ) -> beam.PCollection[MLTransformOutputT]: + """ + Logic to process the data. This will be the entrypoint in + beam.MLTransform to process incoming data. + """ + + @abc.abstractmethod + def append_transform(self, transform: BaseOperation): + """ + Append transforms to the ProcessHandler. + """ + + +class MLTransform(beam.PTransform[beam.PCollection[ExampleT], + beam.PCollection[MLTransformOutputT]], + Generic[ExampleT, MLTransformOutputT]): + def __init__( + self, + *, + artifact_location: str, + artifact_mode: str = ArtifactMode.PRODUCE, + transforms: Optional[Sequence[BaseOperation]] = None): + """ + Args: + artifact_location: A storage location for artifacts resulting from + MLTransform. These artifacts include transformations applied to + the dataset and generated values like min, max from ScaleTo01, + and mean, var from ScaleToZScore. Artifacts are produced and stored + in this location when the `artifact_mode` is set to 'produce'. + Conversely, when `artifact_mode` is set to 'consume', artifacts are + retrieved from this location. Note that when consuming artifacts, + it is not necessary to pass the transforms since they are inherently + stored within the artifacts themselves. The value assigned to + `artifact_location` should be a valid storage path where the artifacts + can be written to or read from. + transforms: A list of transforms to apply to the data. All the transforms + are applied in the order they are specified. The input of the + i-th transform is the output of the (i-1)-th transform. Multi-input + transforms are not supported yet. + artifact_mode: Whether to produce or consume artifacts. If set to + 'consume', the handler will assume that the artifacts are already + computed and stored in the artifact_location. Pass the same artifact + location that was passed during produce phase to ensure that the + right artifacts are read. If set to 'produce', the handler + will compute the artifacts and store them in the artifact_location. + The artifacts will be read from this location during the consume phase. + There is no need to pass the transforms in this case since they are + already embedded in the stored artifacts. + """ + # avoid circular import + # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam.ml.transforms.handlers import TFTProcessHandler + # TODO: When new ProcessHandlers(eg: JaxProcessHandler) are introduced, + # create a mapping between transforms and ProcessHandler since + # ProcessHandler is not exposed to the user. + process_handler: ProcessHandler = TFTProcessHandler( + artifact_location=artifact_location, + artifact_mode=artifact_mode, + transforms=transforms) # type: ignore[arg-type] + + self._process_handler = process_handler + + def expand( + self, pcoll: beam.PCollection[ExampleT] + ) -> beam.PCollection[MLTransformOutputT]: + """ + This is the entrypoint for the MLTransform. This method will + invoke the process_data() method of the ProcessHandler instance + to process the incoming data. + + process_data takes in a PCollection and applies the PTransforms + necessary to process the data and returns a PCollection of + transformed data. + Args: + pcoll: A PCollection of ExampleT type. + Returns: + A PCollection of MLTransformOutputT type. 
+ """ + return self._process_handler.process_data(pcoll) + + def with_transform(self, transform: BaseOperation): + """ + Add a transform to the MLTransform pipeline. + Args: + transform: A BaseOperation instance. + Returns: + A MLTransform instance. + """ + self._process_handler.append_transform(transform) + return self diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py new file mode 100644 index 000000000000..be208c934269 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -0,0 +1,246 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import shutil +import tempfile +import typing +import unittest +from typing import List + +import numpy as np +from parameterized import param +from parameterized import parameterized + +import apache_beam as beam +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports +try: + from apache_beam.ml.transforms import base + from apache_beam.ml.transforms import tft + from apache_beam.ml.transforms.tft import TFTOperation +except ImportError: + tft = None # type: ignore + +if tft is None: + raise unittest.SkipTest('tensorflow_transform is not installed') + + +class _FakeOperation(TFTOperation): + def __init__(self, name, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + + def apply(self, inputs, output_column_name, **kwargs): + return {output_column_name: inputs} + + +class BaseMLTransformTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_ml_transform_appends_transforms_to_process_handler_correctly(self): + fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x']) + transforms = [fake_fn_1] + ml_transform = base.MLTransform( + transforms=transforms, artifact_location=self.artifact_location) + ml_transform = ml_transform.with_transform( + transform=_FakeOperation(name='fake_fn_2', columns=['x'])) + + self.assertEqual(len(ml_transform._process_handler.transforms), 2) + self.assertEqual( + ml_transform._process_handler.transforms[0].name, 'fake_fn_1') + self.assertEqual( + ml_transform._process_handler.transforms[1].name, 'fake_fn_2') + + def test_ml_transform_on_unbatched_dict(self): + transforms = [tft.ScaleTo01(columns=['x'])] + unbatched_data = [{'x': 1}, {'x': 2}] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(unbatched_data) + | base.MLTransform( + artifact_location=self.artifact_location, transforms=transforms)) + expected_output = [ + np.array([0.0], dtype=np.float32), + np.array([1.0], dtype=np.float32), + ] + actual_output = 
result | beam.Map(lambda x: x.x) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + def test_ml_transform_on_batched_dict(self): + transforms = [tft.ScaleTo01(columns=['x'])] + batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + batched_result = ( + p + | beam.Create(batched_data) + | base.MLTransform( + transforms=transforms, artifact_location=self.artifact_location)) + expected_output = [ + np.array([0, 0.2, 0.4], dtype=np.float32), + np.array([0.6, 0.8, 1], dtype=np.float32), + ] + actual_output = batched_result | beam.Map(lambda x: x.x) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + @parameterized.expand([ + param( + input_data=[{ + 'x': 1, + 'y': 2.0, + }], + input_types={ + 'x': int, 'y': float + }, + expected_dtype={ + 'x': typing.Sequence[np.float32], + 'y': typing.Sequence[np.float32], + }, + ), + param( + input_data=[{ + 'x': np.array([1], dtype=np.int64), + 'y': np.array([2.0], dtype=np.float32), + }], + input_types={ + 'x': np.int32, 'y': np.float32 + }, + expected_dtype={ + 'x': typing.Sequence[np.float32], + 'y': typing.Sequence[np.float32], + }, + ), + param( + input_data=[{ + 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] + }], + input_types={ + 'x': List[int], 'y': List[float] + }, + expected_dtype={ + 'x': typing.Sequence[np.float32], + 'y': typing.Sequence[np.float32], + }, + ), + param( + input_data=[{ + 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] + }], + input_types={ + 'x': typing.Sequence[int], + 'y': typing.Sequence[float], + }, + expected_dtype={ + 'x': typing.Sequence[np.float32], + 'y': typing.Sequence[np.float32], + }, + ), + ]) + def test_ml_transform_dict_output_pcoll_schema( + self, input_data, input_types, expected_dtype): + transforms = [tft.ScaleTo01(columns=['x'])] + with beam.Pipeline() as p: + schema_data = ( + p + | beam.Create(input_data) + | beam.Map(lambda x: beam.Row(**x)).with_output_types( + beam.row_type.RowTypeConstraint.from_fields( + list(input_types.items())))) + transformed_data = schema_data | base.MLTransform( + artifact_location=self.artifact_location, transforms=transforms) + for name, typ in transformed_data.element_type._fields: + if name in expected_dtype: + self.assertEqual(expected_dtype[name], typ) + + def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self): + transforms = [tft.ScaleTo01(columns=['x'])] + with beam.Pipeline() as p: + with self.assertRaises(RuntimeError): + _ = ( + p + | beam.Create([{ + 'x': 1, 'y': 2.0 + }]) + | beam.WindowInto(beam.window.FixedWindows(1)) + | base.MLTransform( + transforms=transforms, + artifact_location=self.artifact_location, + artifact_mode=base.ArtifactMode.PRODUCE, + )) + + def test_ml_transform_on_multiple_columns_single_transform(self): + transforms = [tft.ScaleTo01(columns=['x', 'y'])] + batched_data = [{'x': [1, 2, 3], 'y': [1.0, 10.0, 20.0]}] + with beam.Pipeline() as p: + batched_result = ( + p + | beam.Create(batched_data) + | base.MLTransform( + transforms=transforms, artifact_location=self.artifact_location)) + expected_output_x = [ + np.array([0, 0.5, 1], dtype=np.float32), + ] + expected_output_y = [np.array([0, 0.47368422, 1], dtype=np.float32)] + actual_output_x = batched_result | beam.Map(lambda x: x.x) + actual_output_y = batched_result | beam.Map(lambda x: x.y) + assert_that( + actual_output_x, + equal_to(expected_output_x, equals_fn=np.array_equal)) + assert_that( + actual_output_y, + equal_to(expected_output_y, equals_fn=np.array_equal), + label='y') + + def 
test_ml_transforms_on_multiple_columns_multiple_transforms(self): + transforms = [ + tft.ScaleTo01(columns=['x']), + tft.ComputeAndApplyVocabulary(columns=['y']) + ] + batched_data = [{'x': [1, 2, 3], 'y': ['a', 'b', 'c']}] + with beam.Pipeline() as p: + batched_result = ( + p + | beam.Create(batched_data) + | base.MLTransform( + transforms=transforms, artifact_location=self.artifact_location)) + expected_output_x = [ + np.array([0, 0.5, 1], dtype=np.float32), + ] + expected_output_y = [np.array([2, 1, 0])] + actual_output_x = batched_result | beam.Map(lambda x: x.x) + actual_output_y = batched_result | beam.Map(lambda x: x.y) + + assert_that( + actual_output_x, + equal_to(expected_output_x, equals_fn=np.array_equal)) + assert_that( + actual_output_y, + equal_to(expected_output_y, equals_fn=np.array_equal), + label='actual_output_y') + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py new file mode 100644 index 000000000000..9754204f9fe0 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/handlers.py @@ -0,0 +1,410 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import collections +import os +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Union + +import numpy as np + +import apache_beam as beam +import tensorflow as tf +import tensorflow_transform.beam as tft_beam +from apache_beam.ml.transforms.base import ArtifactMode +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.tft import _EXPECTED_TYPES +from apache_beam.ml.transforms.tft import TFTOperation +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import beam_metadata_io +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import metadata_io +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandler', +] + +RAW_DATA_METADATA_DIR = 'raw_data_metadata' +SCHEMA_FILE = 'schema.pbtxt' +# tensorflow transform doesn't support the types other than tf.int64, +# tf.float32 and tf.string. 
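+# For example, int and np.int32 inputs are widened to tf.int64, np.float64 is
+# narrowed to tf.float32, and both str and bytes map to tf.string.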
+_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} +_primitive_types_to_typing_container_type = { + int: List[int], float: List[float], str: List[str], bytes: List[bytes] +} + +tft_process_handler_input_type = typing.Union[typing.NamedTuple, + beam.Row, + Dict[str, + typing.Union[str, + float, + int, + bytes, + np.ndarray]]] +tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]] + + +class ConvertScalarValuesToListValues(beam.DoFn): + def process( + self, + element, + ): + new_dict = {} + for key, value in element.items(): + if isinstance(value, + tuple(_primitive_types_to_typing_container_type.keys())): + new_dict[key] = [value] + else: + new_dict[key] = value + yield new_dict + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. + """ + def expand( + self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]] + ) -> beam.PCollection[common_types.InstanceDictType]: + """ + Args: + pcoll: A PCollection of NamedTuples or Rows. + Returns: + A PCollection of dictionaries. + """ + if isinstance(pcoll.element_type, RowTypeConstraint): + # Row instance + return pcoll | beam.Map(lambda x: x.as_dict()) + else: + # named tuple + return pcoll | beam.Map(lambda x: x._asdict()) + + +class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type, + tft_process_handler_output_type]): + def __init__( + self, + *, + artifact_location: str, + transforms: Optional[Sequence[TFTOperation]] = None, + preprocessing_fn: typing.Optional[typing.Callable] = None, + artifact_mode: str = ArtifactMode.PRODUCE): + """ + A handler class for processing data with TensorFlow Transform (TFT) + operations. This class is intended to be subclassed, with subclasses + implementing the `preprocessing_fn` method. + """ + self.transforms = transforms if transforms else [] + self.transformed_schema: Dict[str, type] = {} + self.artifact_location = artifact_location + self.preprocessing_fn = preprocessing_fn + self.artifact_mode = artifact_mode + if artifact_mode not in ['produce', 'consume']: + raise ValueError('artifact_mode must be either `produce` or `consume`.') + + def append_transform(self, transform): + self.transforms.append(transform) + + def _map_column_names_to_types(self, row_type): + """ + Return a dictionary of column names and types. + Args: + element_type: A type of the element. This could be a NamedTuple or a Row. + Returns: + A dictionary of column names and types. + """ + try: + if not isinstance(row_type, RowTypeConstraint): + row_type = RowTypeConstraint.from_user_type(row_type) + + inferred_types = {name: typ for name, typ in row_type._fields} + + for k, t in inferred_types.items(): + if t in _primitive_types_to_typing_container_type: + inferred_types[k] = _primitive_types_to_typing_container_type[t] + + # sometimes a numpy type can be provided as np.dtype('int64'). + # convert numpy.dtype to numpy type since both are same. 
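+      # e.g. np.dtype('int64') becomes np.int64, so that the later lookup
+      # into _default_type_to_tensor_type_map succeeds.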
+ for name, typ in inferred_types.items(): + if isinstance(typ, np.dtype): + inferred_types[name] = typ.type + + return inferred_types + except: # pylint: disable=bare-except + return {} + + def _map_column_names_to_types_from_transforms(self): + column_type_mapping = {} + for transform in self.transforms: + for col in transform.columns: + if col not in column_type_mapping: + # we just need to dtype of first occurance of column in transforms. + class_name = transform.__class__.__name__ + if class_name not in _EXPECTED_TYPES: + raise KeyError( + f"Transform {class_name} is not registered with a supported " + "type. Please register the transform with a supported type " + "using register_input_dtype decorator.") + column_type_mapping[col] = _EXPECTED_TYPES[ + transform.__class__.__name__] + return column_type_mapping + + def get_raw_data_feature_spec( + self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]: + """ + Return a DatasetMetadata object to be used with + tft_beam.AnalyzeAndTransformDataset. + Args: + input_types: A dictionary of column names and types. + Returns: + A DatasetMetadata object. + """ + raw_data_feature_spec = {} + for key, value in input_types.items(): + raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column( + typ=value, col_name=key) + return raw_data_feature_spec + + def convert_raw_data_feature_spec_to_dataset_metadata( + self, raw_data_feature_spec) -> dataset_metadata.DatasetMetadata: + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + return raw_data_metadata + + def _get_raw_data_feature_spec_per_column( + self, typ: type, col_name: str) -> tf.io.VarLenFeature: + """ + Return a FeatureSpec object to be used with + tft_beam.AnalyzeAndTransformDataset + Args: + typ: A type of the column. + col_name: A name of the column. + Returns: + A FeatureSpec object. + """ + # lets conver the builtin types to typing types for consistency. + typ = native_type_compatibility.convert_builtin_to_typing(typ) + primitive_containers_type = ( + list, + collections.abc.Sequence, + ) + is_primitive_container = ( + typing.get_origin(typ) in primitive_containers_type) + + if is_primitive_container: + dtype = typing.get_args(typ)[0] + if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union: + raise RuntimeError( + f"Union type is not supported for column: {col_name}. " + f"Please pass a PCollection with valid schema for column " + f"{col_name} by passing a single type " + "in container. For example, List[int].") + elif issubclass(typ, np.generic) or typ in _default_type_to_tensor_type_map: + dtype = typ + else: + raise TypeError( + f"Unable to identify type: {typ} specified on column: {col_name}. " + f"Please provide a valid type from the following: " + f"{_default_type_to_tensor_type_map.keys()}") + return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) + + def get_raw_data_metadata( + self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + raw_data_feature_spec = self.get_raw_data_feature_spec(input_types) + return self.convert_raw_data_feature_spec_to_dataset_metadata( + raw_data_feature_spec) + + def write_transform_artifacts(self, transform_fn, location): + """ + Write transform artifacts to the given location. + Args: + transform_fn: A transform_fn object. + location: A location to write the artifacts. + Returns: + A PCollection of WriteTransformFn writing a TF transform graph. 
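As a rough illustration of the type mapping implemented by the helpers above (and exercised in handlers_test.py later in this patch), the feature-spec inference behaves roughly as follows; the column names here are invented for the example.

import tempfile
from typing import List

import tensorflow as tf
from apache_beam.ml.transforms.handlers import TFTProcessHandler

handler = TFTProcessHandler(artifact_location=tempfile.mkdtemp())
# Container and primitive types both end up as variable-length features of
# one of the three supported TF dtypes.
spec = handler.get_raw_data_feature_spec({'x': List[int], 'y': float, 's': str})
assert spec['x'] == tf.io.VarLenFeature(tf.int64)
assert spec['y'] == tf.io.VarLenFeature(tf.float32)
assert spec['s'] == tf.io.VarLenFeature(tf.string)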
+ """ + return ( + transform_fn + | 'Write Transform Artifacts' >> + transform_fn_io.WriteTransformFn(location)) + + def _fail_on_non_default_windowing(self, pcoll: beam.PCollection): + if not pcoll.windowing.is_default(): + raise RuntimeError( + "MLTransform only supports GlobalWindows when producing " + "artifacts such as min, max, variance etc over the dataset." + "Please use beam.WindowInto(beam.transforms.window.GlobalWindows()) " + "to convert your PCollection to GlobalWindow.") + + def process_data_fn( + self, inputs: Dict[str, common_types.ConsistentTensorType] + ) -> Dict[str, common_types.ConsistentTensorType]: + """ + This method is used in the AnalyzeAndTransformDataset step. It applies + the transforms to the `inputs` in sequential order on the columns + provided for a given transform. + Args: + inputs: A dictionary of column names and data. + Returns: + A dictionary of column names and transformed data. + """ + outputs = inputs.copy() + for transform in self.transforms: + columns = transform.columns + for col in columns: + intermediate_result = transform.apply( + outputs[col], output_column_name=col) + for key, value in intermediate_result.items(): + outputs[key] = value + return outputs + + def _get_transformed_data_schema( + self, + metadata: dataset_metadata.DatasetMetadata, + ): + schema = metadata._schema + transformed_types = {} + for feature in schema.feature: + name = feature.name + feature_type = feature.type + if feature_type == schema_pb2.FeatureType.FLOAT: + transformed_types[name] = typing.Sequence[np.float32] + elif feature_type == schema_pb2.FeatureType.INT: + transformed_types[name] = typing.Sequence[np.int64] # type: ignore[assignment] + else: + transformed_types[name] = typing.Sequence[bytes] # type: ignore[assignment] + return transformed_types + + def process_data( + self, raw_data: beam.PCollection[tft_process_handler_input_type] + ) -> beam.PCollection[tft_process_handler_output_type]: + """ + This method also computes the required dataset metadata for the tft + AnalyzeDataset/TransformDataset step. + + This method uses tensorflow_transform's Analyze step to produce the + artifacts and Transform step to apply the transforms on the data. + Artifacts are only produced if the artifact_mode is set to `produce`. + If artifact_mode is set to `consume`, then the artifacts are read from the + artifact_location, which was previously used to store the produced + artifacts. + """ + if self.artifact_mode == ArtifactMode.PRODUCE: + # If we are computing artifacts, we should fail for windows other than + # default windowing since for example, for a fixed window, each window can + # be treated as a separate dataset and we might need to compute artifacts + # for each window. This is not supported yet. + self._fail_on_non_default_windowing(raw_data) + element_type = raw_data.element_type + column_type_mapping = {} + if (isinstance(element_type, RowTypeConstraint) or + native_type_compatibility.match_is_named_tuple(element_type)): + column_type_mapping = self._map_column_names_to_types( + row_type=element_type) + # convert Row or NamedTuple to Dict + raw_data = ( + raw_data + | ConvertNamedTupleToDict().with_output_types( + Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore + # AnalyzeAndTransformDataset raise type hint since this is + # schema'd PCollection and the current output type would be a + # custom type(NamedTuple) or a beam.Row type. 
+ else: + column_type_mapping = self._map_column_names_to_types_from_transforms() + raw_data_metadata = self.get_raw_data_metadata( + input_types=column_type_mapping) + # Write untransformed metadata to a file so that it can be re-used + # during Transform step. + metadata_io.write_metadata( + metadata=raw_data_metadata, + path=os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR)) + else: + # Read the metadata from the artifact_location. + if not os.path.exists(os.path.join( + self.artifact_location, RAW_DATA_METADATA_DIR, SCHEMA_FILE)): + raise FileNotFoundError( + "Artifacts not found at location: %s when artifact_mode=consume." + "Make sure you've run the pipeline in `produce` mode using " + "this artifact location before setting artifact_mode to `consume`." + % os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR)) + raw_data_metadata = metadata_io.read_metadata( + os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR)) + + # To maintain consistency by outputting numpy array all the time, + # whether a scalar value or list or np array is passed as input, + # we will convert scalar values to list values and TFT will ouput + # numpy array all the time. + raw_data |= beam.ParDo(ConvertScalarValuesToListValues()) + + with tft_beam.Context(temp_dir=self.artifact_location): + data = (raw_data, raw_data_metadata) + if self.artifact_mode == ArtifactMode.PRODUCE: + transform_fn = ( + data + | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(self.process_data_fn)) + self.write_transform_artifacts(transform_fn, self.artifact_location) + else: + transform_fn = ( + raw_data.pipeline + | "ReadTransformFn" >> tft_beam.ReadTransformFn( + self.artifact_location)) + (transformed_dataset, transformed_metadata) = ( + ((raw_data, raw_data_metadata), transform_fn) + | "TransformDataset" >> tft_beam.TransformDataset()) + + if isinstance(transformed_metadata, beam_metadata_io.BeamDatasetMetadata): + self.transformed_schema = self._get_transformed_data_schema( + metadata=transformed_metadata.dataset_metadata) + else: + self.transformed_schema = self._get_transformed_data_schema( + transformed_metadata) + + # We will a pass a schema'd PCollection to the next step. + # So we will use a RowTypeConstraint to create a schema'd PCollection. + # this is needed since new columns are included in the + # transformed_dataset. + row_type = RowTypeConstraint.from_fields( + list(self.transformed_schema.items())) + + transformed_dataset |= "ConvertToRowType" >> beam.Map( + lambda x: beam.Row(**x)).with_output_types(row_type) + return transformed_dataset diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py new file mode 100644 index 000000000000..878006550dcf --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -0,0 +1,355 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import os +import shutil +import sys +import tempfile +import typing +import unittest +from typing import List +from typing import NamedTuple +from typing import Union + +import numpy as np +from parameterized import parameterized + +import apache_beam as beam + +# pylint: disable=wrong-import-position, ungrouped-imports +try: + from apache_beam.ml.transforms import handlers + from apache_beam.ml.transforms import tft + from apache_beam.ml.transforms.tft import TFTOperation + from apache_beam.testing.util import assert_that + from apache_beam.testing.util import equal_to + import tensorflow as tf + from tensorflow_transform.tf_metadata import dataset_metadata + from tensorflow_transform.tf_metadata import schema_utils +except ImportError: + tft = None # type: ignore[assignment] + +if not tft: + raise unittest.SkipTest('tensorflow_transform is not installed.') + + +class _AddOperation(TFTOperation): + def apply(self, inputs, output_column_name, **kwargs): + return {output_column_name: inputs + 1} + + +class _MultiplyOperation(TFTOperation): + def apply(self, inputs, output_column_name, **kwargs): + return {output_column_name: inputs * 10} + + +class _FakeOperationWithArtifacts(TFTOperation): + def apply(self, inputs, output_column_name, **kwargs): + return { + **{ + output_column_name: inputs + }, + **(self.get_artifacts(inputs, 'artifact')) + } + + def get_artifacts(self, data, col_name): + return {'artifact': tf.convert_to_tensor([1])} + + +class UnBatchedIntType(NamedTuple): + x: int + + +class BatchedIntType(NamedTuple): + x: List[int] + + +class BatchedNumpyType(NamedTuple): + x: np.int64 + + +class TFTProcessHandlerTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + @parameterized.expand([ + ({ + 'x': 1, 'y': 2 + }, ['x'], { + 'x': 20, 'y': 2 + }), + ({ + 'x': 1, 'y': 2 + }, ['x', 'y'], { + 'x': 20, 'y': 30 + }), + ]) + def test_tft_operation_preprocessing_fn( + self, inputs, columns, expected_result): + add_fn = _AddOperation(columns=columns) + mul_fn = _MultiplyOperation(columns=columns) + process_handler = handlers.TFTProcessHandler( + transforms=[add_fn, mul_fn], artifact_location=self.artifact_location) + actual_result = process_handler.process_data_fn(inputs) + self.assertDictEqual(actual_result, expected_result) + + def test_preprocessing_fn_with_artifacts(self): + process_handler = handlers.TFTProcessHandler( + transforms=[_FakeOperationWithArtifacts(columns=['x'])], + artifact_location=self.artifact_location) + inputs = {'x': [1, 2, 3]} + preprocessing_fn = process_handler.process_data_fn + actual_result = preprocessing_fn(inputs) + expected_result = {'x': [1, 2, 3], 'artifact': tf.convert_to_tensor([1])} + self.assertDictEqual(actual_result, expected_result) + + def test_input_type_from_schema_named_tuple_pcoll_unbatched(self): + non_batched_data = [{'x': 1}] + with beam.Pipeline() as p: + data = ( + p | beam.Create(non_batched_data) + | beam.Map(lambda x: UnBatchedIntType(**x)).with_output_types( + UnBatchedIntType)) + element_type = data.element_type + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + inferred_input_type = process_handler._map_column_names_to_types( + element_type) + expected_input_type = dict(x=List[int]) + + self.assertEqual(inferred_input_type, expected_input_type) + + 
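The fake operations above illustrate the extension point: any TFTOperation subclass implementing apply() can be appended to an MLTransform. Below is a hypothetical end-to-end sketch; the _ScaleByMinMax class is invented for illustration and simply wraps tensorflow_transform.scale_by_min_max, and register_input_dtype is applied so the handler can infer a feature spec when the input PCollection carries no schema.

import tempfile

import apache_beam as beam
import tensorflow_transform
from apache_beam.ml.transforms import tft
from apache_beam.ml.transforms.base import MLTransform


@tft.register_input_dtype(float)
class _ScaleByMinMax(tft.TFTOperation):
  # Hypothetical custom operation built on a tensorflow_transform mapper.
  def apply(self, data, output_column_name):
    return {output_column_name: tensorflow_transform.scale_by_min_max(data)}


with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'x': [1.0, 2.0, 3.0]}])
      | MLTransform(artifact_location=tempfile.mkdtemp()).with_transform(
          _ScaleByMinMax(columns=['x']))
      | beam.Map(lambda row: row.x)
      | beam.Map(print))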
def test_input_type_from_schema_named_tuple_pcoll_batched(self): + batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + data = ( + p | beam.Create(batched_data) + | beam.Map(lambda x: BatchedIntType(**x)).with_output_types( + BatchedIntType)) + element_type = data.element_type + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + inferred_input_type = process_handler._map_column_names_to_types( + element_type) + expected_input_type = dict(x=List[int]) + self.assertEqual(inferred_input_type, expected_input_type) + + def test_input_type_from_row_type_pcoll_unbatched(self): + non_batched_data = [{'x': 1}] + with beam.Pipeline() as p: + data = ( + p | beam.Create(non_batched_data) + | beam.Map(lambda ele: beam.Row(x=int(ele['x'])))) + element_type = data.element_type + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + inferred_input_type = process_handler._map_column_names_to_types( + element_type) + expected_input_type = dict(x=List[int]) + self.assertEqual(inferred_input_type, expected_input_type) + + def test_input_type_from_row_type_pcoll_batched(self): + batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + data = ( + p | beam.Create(batched_data) + | beam.Map(lambda ele: beam.Row(x=list(ele['x']))).with_output_types( + beam.row_type.RowTypeConstraint.from_fields([('x', List[int])]))) + + element_type = data.element_type + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + inferred_input_type = process_handler._map_column_names_to_types( + element_type) + expected_input_type = dict(x=List[int]) + self.assertEqual(inferred_input_type, expected_input_type) + + def test_input_type_from_named_tuple_pcoll_batched_numpy(self): + batched = [{ + 'x': np.array([1, 2, 3], dtype=np.int64) + }, { + 'x': np.array([4, 5, 6], dtype=np.int64) + }] + with beam.Pipeline() as p: + data = ( + p | beam.Create(batched) + | beam.Map(lambda x: BatchedNumpyType(**x)).with_output_types( + BatchedNumpyType)) + element_type = data.element_type + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + inferred_input_type = process_handler._map_column_names_to_types( + element_type) + expected_type = dict(x=np.int64) + self.assertEqual(inferred_input_type, expected_type) + + def test_tensorflow_raw_data_metadata_primitive_types(self): + input_types = dict(x=int, y=float, k=bytes, l=str) + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + + for col_name, typ in input_types.items(): + feature_spec = process_handler._get_raw_data_feature_spec_per_column( + typ=typ, col_name=col_name) + self.assertEqual( + handlers._default_type_to_tensor_type_map[typ], feature_spec.dtype) + self.assertIsInstance(feature_spec, tf.io.VarLenFeature) + + def test_tensorflow_raw_data_metadata_primitive_types_in_containers(self): + input_types = dict([("x", List[int]), ("y", List[float]), + ("k", List[bytes]), ("l", List[str])]) + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + for col_name, typ in input_types.items(): + feature_spec = process_handler._get_raw_data_feature_spec_per_column( + typ=typ, col_name=col_name) + self.assertIsInstance(feature_spec, tf.io.VarLenFeature) + + @unittest.skipIf(sys.version_info < (3, 9), "not supported in python<3.9") + def test_tensorflow_raw_data_metadata_primitive_native_container_types(self): + 
input_types = dict([("x", list[int]), ("y", list[float]), + ("k", list[bytes]), ("l", list[str])]) + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + for col_name, typ in input_types.items(): + feature_spec = process_handler._get_raw_data_feature_spec_per_column( + typ=typ, col_name=col_name) + self.assertIsInstance(feature_spec, tf.io.VarLenFeature) + + def test_tensorflow_raw_data_metadata_numpy_types(self): + input_types = dict(x=np.int64, y=np.float32, z=List[np.int64]) + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + for col_name, typ in input_types.items(): + feature_spec = process_handler._get_raw_data_feature_spec_per_column( + typ=typ, col_name=col_name) + self.assertIsInstance(feature_spec, tf.io.VarLenFeature) + + def test_tensorflow_raw_data_metadata_union_type_in_single_column(self): + input_types = dict(x=Union[int, float]) + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + with self.assertRaises(TypeError): + for col_name, typ in input_types.items(): + _ = process_handler._get_raw_data_feature_spec_per_column( + typ=typ, col_name=col_name) + + def test_tensorflow_raw_data_metadata_dtypes(self): + input_types = dict(x=np.int32, y=np.float64) + expected_dtype = dict(x=np.int64, y=np.float32) + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + for col_name, typ in input_types.items(): + feature_spec = process_handler._get_raw_data_feature_spec_per_column( + typ=typ, col_name=col_name) + self.assertEqual(expected_dtype[col_name], feature_spec.dtype) + + def test_tft_process_handler_default_transform_types(self): + transforms = [ + tft.ScaleTo01(columns=['x']), + tft.ScaleToZScore(columns=['y']), + tft.Bucketize(columns=['z'], num_buckets=2), + tft.ComputeAndApplyVocabulary(columns=['w']) + ] + process_handler = handlers.TFTProcessHandler( + transforms=transforms, artifact_location=self.artifact_location) + column_type_mapping = ( + process_handler._map_column_names_to_types_from_transforms()) + expected_column_type_mapping = { + 'x': float, 'y': float, 'z': float, 'w': str + } + self.assertDictEqual(column_type_mapping, expected_column_type_mapping) + + expected_tft_raw_data_feature_spec = { + 'x': tf.io.VarLenFeature(tf.float32), + 'y': tf.io.VarLenFeature(tf.float32), + 'z': tf.io.VarLenFeature(tf.float32), + 'w': tf.io.VarLenFeature(tf.string) + } + actual_tft_raw_data_feature_spec = ( + process_handler.get_raw_data_feature_spec(column_type_mapping)) + self.assertDictEqual( + actual_tft_raw_data_feature_spec, expected_tft_raw_data_feature_spec) + + def test_tft_process_handler_transformed_data_schema(self): + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location) + raw_data_feature_spec = { + 'x': tf.io.VarLenFeature(tf.float32), + 'y': tf.io.VarLenFeature(tf.float32), + 'z': tf.io.VarLenFeature(tf.string), + } + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + + expected_transformed_data_schema = { + 'x': typing.Sequence[np.float32], + 'y': typing.Sequence[np.float32], + 'z': typing.Sequence[bytes] + } + + actual_transformed_data_schema = ( + process_handler._get_transformed_data_schema(raw_data_metadata)) + self.assertDictEqual( + actual_transformed_data_schema, expected_transformed_data_schema) + + def test_tft_process_handler_verify_artifacts(self): + with beam.Pipeline() as p: + raw_data = ( + p + 
| beam.Create([{ + 'x': np.array([1, 3]) + }, { + 'x': np.array([4, 6]) + }])) + process_handler = handlers.TFTProcessHandler( + transforms=[tft.ScaleTo01(columns=['x'])], + artifact_location=self.artifact_location, + ) + _ = process_handler.process_data(raw_data) + + self.assertTrue( + os.path.exists( + os.path.join( + self.artifact_location, handlers.RAW_DATA_METADATA_DIR))) + self.assertTrue( + os.path.exists( + os.path.join( + self.artifact_location, + handlers.RAW_DATA_METADATA_DIR, + handlers.SCHEMA_FILE))) + + with beam.Pipeline() as p: + raw_data = (p | beam.Create([{'x': np.array([2, 5])}])) + process_handler = handlers.TFTProcessHandler( + artifact_location=self.artifact_location, artifact_mode='consume') + transformed_data = process_handler.process_data(raw_data) + transformed_data |= beam.Map(lambda x: x.x) + + # the previous min is 1 and max is 6. So this should scale by (1, 6) + assert_that( + transformed_data, + equal_to([np.array([0.2, 0.8], dtype=np.float32)], + equals_fn=np.array_equal)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py new file mode 100644 index 000000000000..329a10a74ca1 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -0,0 +1,440 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module defines a set of data processing transforms that can be used +to perform common data transformations on a dataset. These transforms are +implemented using the TensorFlow Transform (TFT) library. The transforms +in this module are intended to be used in conjunction with the +MLTransform class, which provides a convenient interface for +applying a sequence of data processing transforms to a dataset. + +See the documentation for MLTransform for more details. + +Note: The data processing transforms defined in this module don't +perform the transformation immediately. Instead, it returns a +configured operation object, which encapsulates the details of the +transformation. The actual computation takes place later in the Apache Beam +pipeline, after all transformations are set up and the pipeline is run. 
+""" + +# pytype: skip-file + +import logging +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + +import tensorflow as tf +import tensorflow_transform as tft +from apache_beam.ml.transforms.base import BaseOperation +from tensorflow_transform import analyzers +from tensorflow_transform import common_types +from tensorflow_transform import tf_utils + +__all__ = [ + 'ComputeAndApplyVocabulary', + 'ScaleToZScore', + 'ScaleTo01', + 'ApplyBuckets', + 'Bucketize', + 'TFIDF', + 'TFTOperation', +] + +# Register the expected input types for each operation +# this will be used to determine schema for the tft.AnalyzeDataset +_EXPECTED_TYPES: Dict[str, Union[int, str, float]] = {} + +_LOGGER = logging.getLogger(__name__) + + +def register_input_dtype(type): + def wrapper(fn): + _EXPECTED_TYPES[fn.__name__] = type + return fn + + return wrapper + + +class TFTOperation(BaseOperation[common_types.TensorType, + common_types.TensorType]): + def __init__(self, columns: List[str]) -> None: + """ + Base Operation class for TFT data processing transformations. + Processing logic for the transformation is defined in the + apply() method. If you have a custom transformation that is not + supported by the existing transforms, you can extend this class + and implement the apply() method. + Args: + columns: List of column names to apply the transformation. + """ + super().__init__(columns) + if not columns: + raise RuntimeError( + "Columns are not specified. Please specify the column for the " + " op %s" % self.__class__.__name__) + + def get_artifacts(self, data: common_types.TensorType, + col_name: str) -> Dict[str, common_types.TensorType]: + """ + Returns the artifacts generated by the operation. + """ + return {} + + +@register_input_dtype(str) +class ComputeAndApplyVocabulary(TFTOperation): + def __init__( + self, + columns: List[str], + *, + default_value: Any = -1, + top_k: Optional[int] = None, + frequency_threshold: Optional[int] = None, + num_oov_buckets: int = 0, + vocab_filename: Optional[str] = None, + name: Optional[str] = None): + """ + This function computes the vocabulary for the given columns of incoming + data. The transformation converts the input values to indices of the + vocabulary. + + Args: + columns: List of column names to apply the transformation. + default_value: (Optional) The value to use for out-of-vocabulary values. + top_k: (Optional) The number of most frequent tokens to keep. + frequency_threshold: (Optional) Limit the generated vocabulary only to + elements whose absolute frequency is >= to the supplied threshold. + If set to None, the full vocabulary is generated. + num_oov_buckets: Any lookup of an out-of-vocabulary token will return a + bucket ID based on its hash if `num_oov_buckets` is greater than zero. + Otherwise it is assigned the `default_value`. + vocab_filename: The file name for the vocabulary file. If not provided, + the default name would be `compute_and_apply_vocab' + NOTE in order to make your pipelines resilient to implementation + details please set `vocab_filename` when you are using + the vocab_filename on a downstream component. 
+ """ + super().__init__(columns) + self._default_value = default_value + self._top_k = top_k + self._frequency_threshold = frequency_threshold + self._num_oov_buckets = num_oov_buckets + self._vocab_filename = vocab_filename if vocab_filename else ( + 'compute_and_apply_vocab') + self._name = name + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + return { + output_column_name: tft.compute_and_apply_vocabulary( + x=data, + default_value=self._default_value, + top_k=self._top_k, + frequency_threshold=self._frequency_threshold, + num_oov_buckets=self._num_oov_buckets, + vocab_filename=self._vocab_filename, + name=self._name) + } + + +@register_input_dtype(float) +class ScaleToZScore(TFTOperation): + def __init__( + self, + columns: List[str], + *, + elementwise: bool = False, + name: Optional[str] = None): + """ + This function performs a scaling transformation on the specified columns of + the incoming data. It processes the input data such that it's normalized + to have a mean of 0 and a variance of 1. The transformation achieves this + by subtracting the mean from the input data and then dividing it by the + square root of the variance. + + Args: + columns: A list of column names to apply the transformation on. + elementwise: If True, the transformation is applied elementwise. + Otherwise, the transformation is applied on the entire column. + name: A name for the operation (optional). + + scale_to_z_score also outputs additional artifacts. The artifacts are + mean, which is the mean value in the column, and var, which is the + variance in the column. The artifacts are stored in the column + named with the suffix _mean and _var + respectively. + """ + super().__init__(columns) + self.elementwise = elementwise + self.name = name + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + artifacts = self.get_artifacts(data, output_column_name) + output_dict = { + output_column_name: tft.scale_to_z_score( + x=data, elementwise=self.elementwise, name=self.name) + } + if artifacts is not None: + output_dict.update(artifacts) + return output_dict + + def get_artifacts(self, data: common_types.TensorType, + col_name: str) -> Dict[str, common_types.TensorType]: + mean_var = tft.analyzers._mean_and_var(data) + shape = [tf.shape(data)[0], 1] + return { + col_name + '_mean': tf.broadcast_to(mean_var[0], shape), + col_name + '_var': tf.broadcast_to(mean_var[1], shape), + } + + +@register_input_dtype(float) +class ScaleTo01(TFTOperation): + def __init__( + self, + columns: List[str], + elementwise: bool = False, + name: Optional[str] = None): + """ + This function applies a scaling transformation on the given columns + of incoming data. The transformation scales the input values to the + range [0, 1] by dividing each value by the maximum value in the + column. + + Args: + columns: A list of column names to apply the transformation on. + elementwise: If True, the transformation is applied elementwise. + Otherwise, the transformation is applied on the entire column. + name: A name for the operation (optional). + + ScaleTo01 also outputs additional artifacts. The artifacts are + max, which is the maximum value in the column, and min, which is the + minimum value in the column. The artifacts are stored in the column + named with the suffix _min and _max + respectively. 
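Example (a minimal sketch; the input data and expected values come from the tests in this patch, and the artifact location is illustrative):

import tempfile

import apache_beam as beam
from apache_beam.ml.transforms import base
from apache_beam.ml.transforms import tft

artifact_location = tempfile.mkdtemp()
data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}]

with beam.Pipeline() as p:
    result = (
        p
        | beam.Create(data)
        | base.MLTransform(artifact_location=artifact_location).with_transform(
            tft.ScaleTo01(columns=['x'])))
    # With a column-wide min of 1 and max of 6, the scaled outputs are
    # [0, 0.2, 0.4] and [0.6, 0.8, 1]; each row also carries the x_min and
    # x_max artifacts.
    _ = result | beam.Map(lambda row: row.x) | beam.Map(print)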
+ + """ + super().__init__(columns) + self.elementwise = elementwise + self.name = name + + def get_artifacts(self, data: common_types.TensorType, + col_name: str) -> Dict[str, common_types.TensorType]: + shape = [tf.shape(data)[0], 1] + return { + col_name + '_min': tf.broadcast_to(tft.min(data), shape), + col_name + '_max': tf.broadcast_to(tft.max(data), shape) + } + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + artifacts = self.get_artifacts(data, output_column_name) + output = tft.scale_to_0_1( + x=data, elementwise=self.elementwise, name=self.name) + + output_dict = {output_column_name: output} + if artifacts is not None: + output_dict.update(artifacts) + return output_dict + + +@register_input_dtype(float) +class ApplyBuckets(TFTOperation): + def __init__( + self, + columns: List[str], + bucket_boundaries: Iterable[Union[int, float]], + name: Optional[str] = None): + """ + This functions is used to map the element to a positive index i for + which bucket_boundaries[i-1] <= element < bucket_boundaries[i], + if it exists. If input < bucket_boundaries[0], then element is + mapped to 0. If element >= bucket_boundaries[-1], then element is + mapped to len(bucket_boundaries). NaNs are mapped to + len(bucket_boundaries). + + Args: + columns: A list of column names to apply the transformation on. + bucket_boundaries: A rank 2 Tensor or list representing the bucket + boundaries sorted in ascending order. + name: (Optional) A string that specifies the name of the operation. + """ + super().__init__(columns) + self.bucket_boundaries = [bucket_boundaries] + self.name = name + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + output = { + output_column_name: tft.apply_buckets( + x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) + } + return output + + +@register_input_dtype(float) +class Bucketize(TFTOperation): + def __init__( + self, + columns: List[str], + num_buckets: int, + *, + epsilon: Optional[float] = None, + elementwise: bool = False, + name: Optional[str] = None): + """ + This function applies a bucketizing transformation on the given columns + of incoming data. The transformation splits the input data range into + a set of consecutive bins/buckets, and converts the input values to + bucket IDs (integers) where each ID corresponds to a particular bin. + + Args: + columns: List of column names to apply the transformation. + num_buckets: Number of buckets to be created. + epsilon: (Optional) A float number that specifies the error tolerance + when computing quantiles, so that we guarantee that any value x will + have a quantile q such that x is in the interval + [q - epsilon, q + epsilon] (or the symmetric interval for even + num_buckets). Must be greater than 0.0. + elementwise: (Optional) A boolean that specifies whether the quantiles + should be computed on an element-wise basis. If False, the quantiles + are computed globally. + name: (Optional) A string that specifies the name of the operation. 
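Example (a minimal sketch mirroring the tests in this patch; the artifact location is illustrative):

import tempfile

import apache_beam as beam
from apache_beam.ml.transforms import base
from apache_beam.ml.transforms import tft

artifact_location = tempfile.mkdtemp()
data = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}, {'x': 6}]

with beam.Pipeline() as p:
    result = (
        p
        | beam.Create(data)
        | base.MLTransform(artifact_location=artifact_location).with_transform(
            tft.Bucketize(columns=['x'], num_buckets=3)))
    # Values 1..6 map to bucket ids 0, 0, 1, 1, 2, 2; each row also carries
    # the x_quantiles artifact holding the learned boundaries (here [3, 5]).
    _ = result | beam.Map(lambda row: row.x) | beam.Map(print)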
+ """ + super().__init__(columns) + self.num_buckets = num_buckets + self.epsilon = epsilon + self.elementwise = elementwise + self.name = name + + def get_artifacts(self, data: common_types.TensorType, + col_name: str) -> Dict[str, common_types.TensorType]: + num_buckets = self.num_buckets + epsilon = self.epsilon + elementwise = self.elementwise + + if num_buckets < 1: + raise ValueError('Invalid num_buckets %d' % num_buckets) + + if isinstance(data, (tf.SparseTensor, tf.RaggedTensor)) and elementwise: + raise ValueError( + 'bucketize requires `x` to be dense if `elementwise=True`') + + x_values = tf_utils.get_values(data) + + if epsilon is None: + # See explanation in args documentation for epsilon. + epsilon = min(1.0 / num_buckets, 0.01) + + quantiles = analyzers.quantiles( + x_values, num_buckets, epsilon, reduce_instance_dims=not elementwise) + shape = [ + tf.shape(data)[0], num_buckets - 1 if num_buckets > 1 else num_buckets + ] + # These quantiles are used as the bucket boundaries in the later stages. + # Should we change the prefix _quantiles to _bucket_boundaries? + return {col_name + '_quantiles': tf.broadcast_to(quantiles, shape)} + + def apply(self, data: common_types.TensorType, + output_column_name: str) -> Dict[str, common_types.TensorType]: + artifacts = self.get_artifacts(data, output_column_name) + output = { + output_column_name: tft.bucketize( + x=data, + num_buckets=self.num_buckets, + epsilon=self.epsilon, + elementwise=self.elementwise, + name=self.name) + } + if artifacts is not None: + output.update(artifacts) + return output + + +@register_input_dtype(float) +class TFIDF(TFTOperation): + def __init__( + self, + columns: List[str], + vocab_size: Optional[int] = None, + smooth: bool = True, + name: Optional[str] = None, + ): + """ + This function applies a tf-idf transformation on the given columns + of incoming data. + + TFIDF outputs two artifacts for each column: the vocabu index and + the tfidf weight. The vocabu index is a mapping from the original + vocabulary to the new vocabulary. The tfidf weight is a mapping + from the original vocabulary to the tfidf score. + + Input passed to the TFIDF is not modified and used to calculate the + required artifacts. + + Args: + columns: List of column names to apply the transformation. + vocab_size: (Optional) An integer that specifies the size of the + vocabulary. Defaults to None. + + If vocab_size is None, then the size of the vocabulary is + determined by `tft.get_num_buckets_for_transformed_feature`. + smooth: (Optional) A boolean that specifies whether to apply + smoothing to the tf-idf score. Defaults to True. + name: (Optional) A string that specifies the name of the operation. + """ + super().__init__(columns) + self.vocab_size = vocab_size + self.smooth = smooth + self.name = name + self.tfidf_weight = None + + def apply( + self, data: tf.SparseTensor, output_column_name: str) -> tf.SparseTensor: + + if self.vocab_size is None: + try: + _LOGGER.info( + 'vocab_size is not specified. Trying to infer vocab_size ' + 'from the input data using ' + 'tft.get_num_buckets_for_transformed_feature.') + vocab_size = tft.get_num_buckets_for_transformed_feature(data) + except RuntimeError: + raise RuntimeError( + 'vocab_size is not specified. Tried to infer vocab_size from the ' + 'input data using tft.get_num_buckets_for_transformed_feature, but ' + 'failed. 
Please specify vocab_size explicitly.') + else: + vocab_size = self.vocab_size + + vocab_index, tfidf_weight = tft.tfidf( + data, + vocab_size, + self.smooth, + self.name + ) + + output = { + output_column_name + '_vocab_index': vocab_index, + output_column_name + '_tfidf_weight': tfidf_weight + } + return output diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py new file mode 100644 index 000000000000..66578c7366dc --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/tft_test.py @@ -0,0 +1,395 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import shutil +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized + +import apache_beam as beam +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=wrong-import-order, wrong-import-position +try: + from apache_beam.ml.transforms import base + from apache_beam.ml.transforms import tft +except ImportError: + tft = None # type: ignore[assignment] + +if not tft: + raise unittest.SkipTest('tensorflow_transform is not installed.') + +z_score_expected = {'x_mean': 3.5, 'x_var': 2.9166666666666665} + + +def assert_z_score_artifacts(element): + element = element.as_dict() + assert 'x_mean' in element + assert 'x_var' in element + assert element['x_mean'] == z_score_expected['x_mean'] + assert element['x_var'] == z_score_expected['x_var'] + + +def assert_ScaleTo01_artifacts(element): + element = element.as_dict() + assert 'x_min' in element + assert 'x_max' in element + assert element['x_min'] == 1 + assert element['x_max'] == 6 + + +def assert_bucketize_artifacts(element): + element = element.as_dict() + assert 'x_quantiles' in element + assert np.array_equal( + element['x_quantiles'], np.array([3, 5], dtype=np.float32)) + + +class ScaleZScoreTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_z_score_unbatched(self): + unbatched_data = [{ + 'x': 1 + }, { + 'x': 2 + }, { + 'x': 3 + }, { + 'x': 4 + }, { + 'x': 5 + }, { + 'x': 6 + }] + + with beam.Pipeline() as p: + unbatched_result = ( + p + | "unbatchedCreate" >> beam.Create(unbatched_data) + | "unbatchedMLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ScaleToZScore(columns=['x']))) + _ = (unbatched_result | beam.Map(assert_z_score_artifacts)) + + def test_z_score_batched(self): + batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + batched_result = ( + p + | "batchedCreate" >> beam.Create(batched_data) + | "batchedMLTransform" >> base.MLTransform( + 
artifact_location=self.artifact_location).with_transform( + tft.ScaleToZScore(columns=['x']))) + _ = (batched_result | beam.Map(assert_z_score_artifacts)) + + +class ScaleTo01Test(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_ScaleTo01_batched(self): + batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + batched_result = ( + p + | "batchedCreate" >> beam.Create(batched_data) + | "batchedMLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ScaleTo01(columns=['x']))) + _ = (batched_result | beam.Map(assert_ScaleTo01_artifacts)) + + expected_output = [ + np.array([0, 0.2, 0.4], dtype=np.float32), + np.array([0.6, 0.8, 1], dtype=np.float32) + ] + actual_output = (batched_result | beam.Map(lambda x: x.x)) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + def test_ScaleTo01_unbatched(self): + unbatched_data = [{ + 'x': 1 + }, { + 'x': 2 + }, { + 'x': 3 + }, { + 'x': 4 + }, { + 'x': 5 + }, { + 'x': 6 + }] + with beam.Pipeline() as p: + unbatched_result = ( + p + | "unbatchedCreate" >> beam.Create(unbatched_data) + | "unbatchedMLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ScaleTo01(columns=['x']))) + + _ = (unbatched_result | beam.Map(assert_ScaleTo01_artifacts)) + expected_output = ( + np.array([0], dtype=np.float32), + np.array([0.2], dtype=np.float32), + np.array([0.4], dtype=np.float32), + np.array([0.6], dtype=np.float32), + np.array([0.8], dtype=np.float32), + np.array([1], dtype=np.float32)) + actual_output = (unbatched_result | beam.Map(lambda x: x.x)) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + +class BucketizeTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_bucketize_unbatched(self): + unbatched = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}, {'x': 6}] + with beam.Pipeline() as p: + unbatched_result = ( + p + | "unbatchedCreate" >> beam.Create(unbatched) + | "unbatchedMLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.Bucketize(columns=['x'], num_buckets=3))) + _ = (unbatched_result | beam.Map(assert_bucketize_artifacts)) + + transformed_data = (unbatched_result | beam.Map(lambda x: x.x)) + expected_data = [ + np.array([0]), + np.array([0]), + np.array([1]), + np.array([1]), + np.array([2]), + np.array([2]) + ] + assert_that( + transformed_data, equal_to(expected_data, equals_fn=np.array_equal)) + + def test_bucketize_batched(self): + batched = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] + with beam.Pipeline() as p: + batched_result = ( + p + | "batchedCreate" >> beam.Create(batched) + | "batchedMLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.Bucketize(columns=['x'], num_buckets=3))) + _ = (batched_result | beam.Map(assert_bucketize_artifacts)) + + transformed_data = ( + batched_result + | "TransformedColumnX" >> beam.Map(lambda ele: ele.x)) + expected_data = [ + np.array([0, 0, 1], dtype=np.int64), + np.array([1, 2, 2], dtype=np.int64) + ] + assert_that( + transformed_data, equal_to(expected_data, equals_fn=np.array_equal)) + + @parameterized.expand([ + (range(1, 10), [4, 7]), + (range(9, 0, -1), [4, 7]), + (range(19, 0, -1), [10]), 
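+      # Each case pairs an input range with the quantile boundaries TFT is
+      # expected to learn for num_buckets = len(expected_boundaries) + 1.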
+ (range(1, 100), [25, 50, 75]), + # similar to the above but with odd number of elements + (range(1, 100, 2), [25, 51, 75]), + (range(99, 0, -1), range(10, 100, 10)) + ]) + def test_bucketize_boundaries(self, test_input, expected_boundaries): + # boundaries are outputted as artifacts for the Bucketize transform. + data = [{'x': [i]} for i in test_input] + num_buckets = len(expected_boundaries) + 1 + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.Bucketize(columns=['x'], num_buckets=num_buckets))) + actual_boundaries = ( + result + | beam.Map(lambda x: x.as_dict()) + | beam.Map(lambda x: x['x_quantiles'])) + + def assert_boundaries(actual_boundaries): + assert np.array_equal(actual_boundaries, expected_boundaries) + + _ = (actual_boundaries | beam.Map(assert_boundaries)) + + +class ApplyBucketsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + @parameterized.expand([ + (range(1, 100), [25, 50, 75]), + (range(1, 100, 2), [25, 51, 75]), + ]) + def test_apply_buckets(self, test_inputs, bucket_boundaries): + with beam.Pipeline() as p: + data = [{'x': [i]} for i in test_inputs] + result = ( + p + | "Create" >> beam.Create(data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ApplyBuckets( + columns=['x'], bucket_boundaries=bucket_boundaries))) + expected_output = [] + bucket = 0 + for x in sorted(test_inputs): + # Increment the bucket number when crossing the boundary + if (bucket < len(bucket_boundaries) and x >= bucket_boundaries[bucket]): + bucket += 1 + expected_output.append(np.array([bucket])) + + actual_output = (result | beam.Map(lambda x: x.x)) + assert_that( + actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + + +class ComputeAndApplyVocabTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_compute_and_apply_vocabulary_unbatched_inputs(self): + batch_size = 100 + num_instances = batch_size + 1 + input_data = [{ + 'x': '%.10i' % i, # Front-padded to facilitate lexicographic sorting. + } for i in range(num_instances)] + + expected_data = [{ + 'x': (len(input_data) - 1) - i, # Due to reverse lexicographic sorting. + } for i in range(len(input_data))] + + with beam.Pipeline() as p: + actual_data = ( + p + | "Create" >> beam.Create(input_data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ComputeAndApplyVocabulary(columns=['x']))) + actual_data |= beam.Map(lambda x: x.as_dict()) + + assert_that(actual_data, equal_to(expected_data)) + + def test_compute_and_apply_vocabulary_batched(self): + batch_size = 100 + num_instances = batch_size + 1 + input_data = [ + { + 'x': ['%.10i' % i, '%.10i' % (i + 1), '%.10i' % (i + 2)], + # Front-padded to facilitate lexicographic sorting. + } for i in range(0, num_instances, 3) + ] + + # since we have 3 elements in a single batch, multiply with 3 for + # each iteration i on the expected output. + excepted_data = [ + np.array([(len(input_data) * 3 - 1) - i, + (len(input_data) * 3 - 1) - i - 1, + (len(input_data) * 3 - 1) - i - 2], + dtype=np.int64) # Front-padded to facilitate lexicographic + # sorting. 
+ for i in range(0, len(input_data) * 3, 3) + ] + + with beam.Pipeline() as p: + result = ( + p + | "Create" >> beam.Create(input_data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location).with_transform( + tft.ComputeAndApplyVocabulary(columns=['x']))) + actual_output = (result | beam.Map(lambda x: x.x)) + assert_that( + actual_output, equal_to(excepted_data, equals_fn=np.array_equal)) + + +class TFIDIFTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.artifact_location) + + def test_tfidf_batched_compute_vocab_size_during_runtime(self): + raw_data = [ + dict(x=["I", "like", "pie", "pie", "pie"]), + dict(x=["yum", "yum", "pie"]) + ] + with beam.Pipeline() as p: + transforms = [ + tft.ComputeAndApplyVocabulary(columns=['x']), + tft.TFIDF(columns=['x']) + ] + actual_output = ( + p + | "Create" >> beam.Create(raw_data) + | "MLTransform" >> base.MLTransform( + artifact_location=self.artifact_location, transforms=transforms)) + actual_output |= beam.Map(lambda x: x.as_dict()) + + def equals_fn(a, b): + is_equal = True + for key, value in a.items(): + value_b = a[key] + is_equal = is_equal and np.array_equal(value, value_b) + return is_equal + + expected_output = ([{ + 'x': np.array([3, 2, 0, 0, 0]), + 'x_tfidf_weight': np.array([0.6, 0.28109303, 0.28109303], + dtype=np.float32), + 'x_vocab_index': np.array([0, 2, 3], dtype=np.int64) + }, + { + 'x': np.array([1, 1, 0]), + 'x_tfidf_weight': np.array( + [0.33333334, 0.9369768], dtype=np.float32), + 'x_vocab_index': np.array([0, 1], dtype=np.int32) + }]) + assert_that(actual_output, equal_to(expected_output, equals_fn=equals_fn)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py new file mode 100644 index 000000000000..1f1fa729b160 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/utils.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ['ArtifactsFetcher'] + +import typing + +import tensorflow_transform as tft + + +class ArtifactsFetcher(): + """ + Utility class used to fetch artifacts from the artifact_location passed + to the TFTProcessHandlers in MLTransform. + """ + def __init__(self, artifact_location): + self.artifact_location = artifact_location + self.transform_output = tft.TFTransformOutput(self.artifact_location) + + def get_vocab_list( + self, + vocab_filename: str = 'compute_and_apply_vocab') -> typing.List[bytes]: + """ + Returns list of vocabulary terms created during MLTransform. 
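Example (a minimal sketch; the artifact location below is a hypothetical path that must have been produced by an earlier MLTransform run containing a ComputeAndApplyVocabulary transform):

from apache_beam.ml.transforms.utils import ArtifactsFetcher

# Hypothetical directory written by a previous MLTransform run.
fetcher = ArtifactsFetcher(artifact_location='/tmp/ml_transform_artifacts')
vocab = fetcher.get_vocab_list()  # uses the default 'compute_and_apply_vocab'
print(fetcher.get_vocab_size('compute_and_apply_vocab'), vocab[:5])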
+ """ + try: + vocab_list = self.transform_output.vocabulary_by_name(vocab_filename) + except ValueError as e: + raise ValueError( + 'Vocabulary file {} not found in artifact location'.format( + vocab_filename)) from e + return [x.decode('utf-8') for x in vocab_list] + + def get_vocab_filepath( + self, vocab_filename: str = 'compute_and_apply_vocab') -> str: + """ + Return the path to the vocabulary file created during MLTransform. + """ + return self.transform_output.vocabulary_file_by_name(vocab_filename) + + def get_vocab_size(self, vocab_filename: str) -> int: + return self.transform_output.vocabulary_size_by_name(vocab_filename) diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 11bb131e88b7..06ea4dc3622c 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -131,7 +131,7 @@ release = version autoclass_content = 'both' autodoc_inherit_docstrings = False autodoc_member_order = 'bysource' -autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub"] +autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub", "tensorflow_transform", "tensorflow_metadata"] # Allow a special section for documenting DataFrame API napoleon_custom_sections = ['Differences from pandas'] diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle index a96e1c5c9f6b..5b5b48298bd3 100644 --- a/sdks/python/test-suites/tox/py38/build.gradle +++ b/sdks/python/test-suites/tox/py38/build.gradle @@ -116,6 +116,10 @@ toxTask "testPy38pytorch-200", "py38-pytorch-200", "${posargs}" test.dependsOn "testPy38pytorch-200" preCommitPyCoverage.dependsOn "testPy38pytorch-200" +toxTask "testPy38tft-113", "py38-tft-113", "${posargs}" +test.dependsOn "testPy38tft-113" +preCommitPyCoverage.dependsOn "testPy38tft-113" + // TODO(https://github.com/apache/beam/issues/25796) - uncomment onnx tox task once onnx supports protobuf 4.x.x // Create a test task for each minor version of onnx // toxTask "testPy38onnx-113", "py38-onnx-113", "${posargs}" diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index b2f784aada50..3f1b32a20d22 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -324,6 +324,12 @@ commands = # Run all DataFrame API unit tests bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/dataframe' +[testenv:py{38,39}-tft-113] +deps = + 113: tensorflow_transform>=1.13.0,<1.14.0 +commands = + bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/ml/transforms' + [testenv:py{38,39,310,311}-pytorch-{19,110,111,112,113}] deps = -r build-requirements.txt