Skip to content

Commit

Permalink
Add basic testing for yaml join docs.
Browse files Browse the repository at this point in the history
As we're using the fake SQL and not checking the outputs,
this is not substitute for real testing, but at the very
least it ensures the joins are syntactically correct.
  • Loading branch information
robertwb committed Aug 9, 2024
1 parent f73a6d1 commit e2bf5d6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
40 changes: 34 additions & 6 deletions sdks/python/apache_beam/yaml/readme_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def expand(self, inputs):

def guess_name_and_type(expr):
expr = expr.strip().replace('`', '')
if expr.endswith('*'):
return 'unknown', str
parts = expr.split()
if len(parts) >= 2 and parts[-2].lower() == 'as':
name = parts[-1]
Expand Down Expand Up @@ -87,7 +89,7 @@ def guess_name_and_type(expr):
return name, typ

if m.group(1) == '*':
return inputs['PCOLLECTION'] | beam.Filter(lambda _: True)
return next(iter(inputs.values())) | beam.Filter(lambda _: True)
else:
output_schema = [
guess_name_and_type(expr) for expr in m.group(1).split(',')
Expand Down Expand Up @@ -280,17 +282,40 @@ def parse_test_methods(markdown_lines):
else:
if code_lines:
if code_lines[0].startswith('- type:'):
is_chain = not any('input:' in line for line in code_lines)
specs = yaml.load('\n'.join(code_lines), Loader=SafeLoader)
is_chain = not any('input' in spec for spec in specs)
if is_chain:
undefined_inputs = set(['input'])
else:

def extract_inputs(input_spec):
if not input_spec:
return set()
elif isinstance(input_spec, str):
return set([input_spec.split('.')[0]])
elif isinstance(input_spec, list):
return set.union(*[extract_inputs(v) for v in input_spec])
elif isinstance(input_spec, dict):
return set.union(
*[extract_inputs(v) for v in input_spec.values()])
else:
raise ValueError("Misformed inputs: " + input_spec)

def extract_name(input_spec):
return input_spec.get('name', input_spec.get('type'))

undefined_inputs = set.union(
*[extract_inputs(spec.get('input')) for spec in specs]) - set(
extract_name(spec) for spec in specs)
# Treat this as a fragment of a larger pipeline.
# pylint: disable=not-an-iterable
code_lines = [
'pipeline:',
' type: chain' if is_chain else '',
' transforms:',
' - type: ReadFromCsv',
' name: input',
' config:',
' path: whatever',
] + [
' - {type: ReadFromCsv, name: "%s", config: {path: x}}' %
undefined_input for undefined_input in undefined_inputs
] + [' ' + line for line in code_lines]
if code_lines[0] == 'pipeline:':
yaml_pipeline = '\n'.join(code_lines)
Expand Down Expand Up @@ -329,6 +354,9 @@ def createTestSuite(name, path):
InlinePythonTest = createTestSuite(
'InlinePythonTest', os.path.join(YAML_DOCS_DIR, 'yaml-inline-python.md'))

JoinTest = createTestSuite(
'JoinTest', os.path.join(YAML_DOCS_DIR, 'yaml-join.md'))

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--render_dir', default=None)
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def _is_connected(edge_list, expected_node_count):
def _SqlJoinTransform(
pcolls,
sql_transform_constructor,
type: Union[str, Dict[str, List]],
*,
equalities: Union[str, List[Dict[str, str]]],
type: Union[str, Dict[str, List]] = 'inner',
fields: Optional[Dict[str, Any]] = None):
"""Joins two or more inputs using a specified condition.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ inputs, one can use the following shorthand syntax:
input2: Second Input
input3: Third Input
config:
equalities: col
equalities: col1
```

## Join Types
Expand Down

0 comments on commit e2bf5d6

Please sign in to comment.