234 changes: 218 additions & 16 deletions sdks/python/apache_beam/yaml/examples/testing/examples_test.py
@@ -45,17 +45,35 @@
from apache_beam.yaml.readme_test import replace_recursive


# Used to simulate Enrichment transform during tests
# The GitHub action that invokes these tests does not
# have gcp dependencies installed which is a prerequisite
# to apache_beam.transforms.enrichment.Enrichment as a top-level
# import.
@beam.ptransform.ptransform_fn
def test_enrichment(
pcoll,
enrichment_handler: str,
handler_config: Dict[str, Any],
timeout: Optional[float] = 30):
"""
Mocks the Enrichment transform for testing purposes.

This PTransform simulates the behavior of the Enrichment transform by
looking up data from predefined in-memory tables based on the provided
`enrichment_handler` and `handler_config`.

Note: The GitHub action that invokes these tests does not have GCP
dependencies installed, which are a prerequisite for importing
apache_beam.transforms.enrichment.Enrichment at the top level.

Args:
pcoll: The input PCollection.
enrichment_handler: A string indicating the type of enrichment handler
to simulate (e.g., 'BigTable', 'BigQuery').
handler_config: A dictionary containing configuration details for the
simulated handler (e.g., table names, row keys, fields).
timeout: An optional timeout value (ignored in this mock).

Returns:
A PCollection containing the enriched data.
"""

if enrichment_handler == 'BigTable':
row_key = handler_config['row_key']
bt_data = INPUT_TABLES[(
@@ -97,6 +115,20 @@ def _fn(row):


def check_output(expected: List[str]):
"""
Helper function to check the output of a pipeline against expected values.

This function takes a list of expected output strings and returns a
callable that can be used within a Beam pipeline to assert that the
actual output matches the expected output.

Args:
expected: A list of strings representing the expected output elements.

Returns:
A callable that takes a list of PCollections and asserts their combined
elements match the expected output.
"""
def _check_inner(actual: List[PCollection[str]]):
formatted_actual = actual | beam.Flatten() | beam.Map(
lambda row: str(beam.Row(**row._asdict())))
@@ -232,6 +264,23 @@ def bigquery_data():
def create_test_method(
pipeline_spec_file: str,
custom_preprocessors: List[Callable[..., Union[Dict, List]]]):
"""
Generates a test method for a given YAML pipeline specification file.

This function reads the YAML file, extracts the expected output (if present),
and creates a test function that uses `TestPipeline` to run the pipeline
defined in the YAML file. It also applies any custom preprocessors registered
for this test.

Args:
pipeline_spec_file: The path to the YAML file containing the pipeline
specification.
custom_preprocessors: A list of preprocessor functions to apply before
running the test.

Returns:
A test method (Callable) that can be added to a unittest.TestCase class.
"""
@mock.patch('apache_beam.Pipeline', TestPipeline)
def test_yaml_example(self):
with open(pipeline_spec_file, encoding="utf-8") as f:
@@ -294,16 +343,55 @@ def test_yaml_example(self):


class YamlExamplesTestSuite:
"""
YamlExamplesTestSuites class is used to scan specified directories for .yaml
files and dynamically generate a Python test method. Additionally, it creates
a method to complete some preprocessing for mocking IO.
"""
_test_preprocessor: Dict[str, List[Callable[..., Union[Dict, List]]]] = {}

def __init__(self, name: str, path: str):
"""
Initializes the YamlExamplesTestSuite.

Args:
name: The name of the test suite. This will be used as the class name
for the dynamically generated test suite.
path: A string representing the path or glob pattern to search for
YAML example files.
"""
self._test_suite = self.create_test_suite(name, path)

def run(self):
"""
Runs the dynamically generated test suite.

This method simply returns the test suite class created during
initialization. The test runner (e.g., unittest.main()) can then be used
to discover and run the tests within this suite.

Returns:
The dynamically created unittest.TestCase subclass.
"""
return self._test_suite

@classmethod
def parse_test_methods(cls, path: str):
"""Scans a given path for YAML files and generates test methods.

This method uses glob to find files matching the provided path. For each
YAML file found, it constructs a unique test name and then calls
`create_test_method` to generate the actual test function.
It also retrieves any registered preprocessors for that specific test.

Args:
path: A string representing the path or glob pattern to search for
YAML example files.

Yields:
A tuple containing the generated test name (str) and the
corresponding test method (Callable).
"""
files = glob.glob(path)
if not files and os.path.exists(path) and os.path.isfile(path):
files = [path]
@@ -314,11 +402,44 @@ def parse_test_methods(cls, path: str):

@classmethod
def create_test_suite(cls, name: str, path: str):
"""Dynamically creates a unittest.TestCase subclass with generated tests.

This method takes a suite name and a path (or glob pattern). It uses
`parse_test_methods` to find YAML files at the given path and generate
individual test methods for each. These generated test methods are then
added as attributes to a new class, which is a subclass of
`unittest.TestCase`.

Args:
name: The desired name for the dynamically created test suite class.
path: A string representing the path or glob pattern to search for
YAML example files, which will be used to generate test methods.

Returns:
A new class, subclass of `unittest.TestCase`, containing dynamically
generated test methods based on the YAML files found at the given path.
"""
return type(
name, (unittest.TestCase, ), dict(cls.parse_test_methods(path)))

@classmethod
def register_test_preprocessor(cls, test_names: Union[str, List]):
"""Decorator to register a preprocessor function for specific tests.

This decorator is used to associate a preprocessor function with one or
more test names. The preprocessor function will be called before the
corresponding test is executed, allowing for modification of the test
specification or environment setup.

Args:
test_names: A string or a list of strings representing the names of the
tests for which the preprocessor should be registered. The test names
should match the names generated by `parse_test_methods`.

Returns:
A decorator function that takes the preprocessor function as an argument
and registers it.
"""
if isinstance(test_names, str):
test_names = [test_names]

@@ -335,6 +456,23 @@ def apply(preprocessor):
@YamlExamplesTestSuite.register_test_preprocessor('test_wordcount_minimal_yaml')
def _wordcount_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
"""
Preprocessor for the wordcount_minimal.yaml test.

This preprocessor generates a random input file based on the expected output
of the wordcount example. This allows the test to verify the pipeline's
correctness without relying on a fixed input file.

Args:
test_spec: The dictionary representation of the YAML pipeline specification.
expected: A list of strings representing the expected output of the
pipeline.
env: The TestEnvironment object providing utilities for creating temporary
files.

Returns:
The modified test_spec dictionary with the input file path replaced.
"""
all_words = []
for element in expected:
word = element.split('=')[1].split(',')[0].replace("'", '')
@@ -366,7 +504,23 @@ def _wordcount_test_preprocessor(
])
def _io_write_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
"""
Preprocessor for tests that involve writing to IO.

This preprocessor replaces any WriteTo transform with a LogForTesting
transform. This allows the test to verify the data being written without
actually writing to an external system.

Args:
test_spec: The dictionary representation of the YAML pipeline specification.
expected: A list of strings representing the expected output of the
pipeline.
env: The TestEnvironment object providing utilities for creating temporary
files.

Returns:
The modified test_spec dictionary with WriteTo transforms replaced.
"""
if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
if transform.get('type', '').startswith('WriteTo'):
@@ -384,6 +538,21 @@ def _io_write_test_preprocessor(
['test_simple_filter_yaml', 'test_simple_filter_and_combine_yaml'])
def _file_io_read_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
"""
This preprocessor replaces any ReadFrom transform with a Create transform
that reads from a predefined in-memory dictionary. This allows the test
to verify the pipeline's correctness without relying on external files.

Args:
test_spec: The dictionary representation of the YAML pipeline specification.
expected: A list of strings representing the expected output of the
pipeline.
env: The TestEnvironment object providing utilities for creating temporary
files.

Returns:
The modified test_spec dictionary with ReadFrom transforms replaced.
"""

if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
@@ -402,7 +571,24 @@ def _file_io_read_test_preprocessor(
['test_spanner_read_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
def _spanner_io_read_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
"""
Preprocessor for tests that involve reading from Spanner.

This preprocessor replaces any ReadFromSpanner transform with a Create
transform that reads from a predefined in-memory dictionary. This allows
the test to verify the pipeline's correctness without relying on external
Spanner instances.

Args:
test_spec: The dictionary representation of the YAML pipeline specification.
expected: A list of strings representing the expected output of the
pipeline.
env: The TestEnvironment object providing utilities for creating temporary
files.

Returns:
The modified test_spec dictionary with ReadFromSpanner transforms replaced.
"""
if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
if transform.get('type', '').startswith('ReadFromSpanner'):
@@ -436,6 +622,23 @@ def _spanner_io_read_test_preprocessor(
['test_bigtable_enrichment_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
def _enrichment_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
"""
Preprocessor for tests that involve the Enrichment transform.

This preprocessor replaces the actual Enrichment transform with a mock
`TestEnrichment` transform. This allows the test to verify the pipeline's
correctness without requiring external services like BigTable or BigQuery.

Args:
test_spec: The dictionary representation of the YAML pipeline specification.
expected: A list of strings representing the expected output of the
pipeline.
env: The TestEnvironment object providing utilities for creating temporary
files.

Returns:
The modified test_spec dictionary with Enrichment transforms replaced.
"""
if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
if transform.get('type', '').startswith('Enrichment'):
Expand All @@ -451,23 +654,22 @@ def _enrichment_test_preprocessor(
('BigTable', 'beam-test', 'bigtable-enrichment-test'): bigtable_data(),
('BigQuery', 'ALL_TEST', 'customers'): bigquery_data()
}

YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__))
ExamplesTest = YamlExamplesTestSuite(
'ExamplesTest', os.path.join(YAML_DOCS_DIR, '../*.yaml')).run()

ElementWiseTest = YamlExamplesTestSuite(
'ElementwiseExamplesTest',
os.path.join(YAML_DOCS_DIR, '../transforms/elementwise/*.yaml')).run()

AggregationTest = YamlExamplesTestSuite(
'AggregationExamplesTest',
os.path.join(YAML_DOCS_DIR, '../transforms/aggregation/*.yaml')).run()

BlueprintsTest = YamlExamplesTestSuite(
'BlueprintsExamplesTest',
os.path.join(YAML_DOCS_DIR, '../transforms/blueprints/*.yaml')).run()

IOTest = YamlExamplesTestSuite(
'IOExamplesTest', os.path.join(YAML_DOCS_DIR,
'../transforms/io/*.yaml')).run()

MLTest = YamlExamplesTestSuite(
'MLExamplesTest', os.path.join(YAML_DOCS_DIR,
'../transforms/ml/*.yaml')).run()