diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index 23b9beadd336..9c2762dff710 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -20,6 +20,7 @@ # pytype: skip-file import collections.abc +import enum import sys import typing import unittest @@ -54,6 +55,11 @@ class _TestPair(typing.NamedTuple('TestTuple', [('first', T), ('second', T)]), pass +class _TestEnum(enum.Enum): + FOO = enum.auto() + BAR = enum.auto() + + class NativeTypeCompatibilityTest(unittest.TestCase): def test_convert_to_beam_type(self): test_cases = [ @@ -106,6 +112,7 @@ def test_convert_to_beam_type(self): typehints.List[_TestGeneric[int]]), ('nested generic with any', typing.List[_TestPair[typing.Any]], typehints.List[_TestPair[typing.Any]]), + ('raw enum', _TestEnum, _TestEnum), ] for test_case in test_cases: @@ -122,20 +129,22 @@ def test_convert_to_beam_type(self): def test_convert_to_beam_type_with_builtin_types(self): if sys.version_info >= (3, 9): - test_cases = [('builtin dict', dict[str, int], typehints.Dict[str, int]), - ('builtin list', list[str], typehints.List[str]), - ('builtin tuple', tuple[str], typehints.Tuple[str]), - ('builtin set', set[str], typehints.Set[str]), - ( - 'nested builtin', - dict[str, list[tuple[float]]], - typehints.Dict[str, - typehints.List[typehints.Tuple[float]]]), - ( - 'builtin nested tuple', - tuple[str, list], - typehints.Tuple[str, typehints.List[typehints.Any]], - )] + test_cases = [ + ('builtin dict', dict[str, int], typehints.Dict[str, int]), + ('builtin list', list[str], typehints.List[str]), + ('builtin tuple', tuple[str], + typehints.Tuple[str]), ('builtin set', set[str], typehints.Set[str]), + ('builtin frozenset', frozenset[int], typehints.FrozenSet[int]), + ( + 'nested builtin', + dict[str, list[tuple[float]]], + typehints.Dict[str, typehints.List[typehints.Tuple[float]]]), + ( + 'builtin nested tuple', + tuple[str, list], + typehints.Tuple[str, typehints.List[typehints.Any]], + ) + ] for test_case in test_cases: description = test_case[0] @@ -173,6 +182,14 @@ def test_convert_to_beam_type_with_collections_types(self): collections.abc.Mapping[str, int]), ('set', collections.abc.Set[str], typehints.Set[str]), ('mutable set', collections.abc.MutableSet[int], typehints.Set[int]), + ( + 'enum set', + collections.abc.Set[_TestEnum], + typehints.Set[_TestEnum]), + ( + 'enum mutable set', + collections.abc.MutableSet[_TestEnum], + typehints.Set[_TestEnum]) ] for test_case in test_cases: