diff --git a/bindings/pydrake/common/BUILD.bazel b/bindings/pydrake/common/BUILD.bazel index c4da51d50a15..8e76fb2da93d 100644 --- a/bindings/pydrake/common/BUILD.bazel +++ b/bindings/pydrake/common/BUILD.bazel @@ -813,6 +813,17 @@ drake_py_unittest( ], ) +drake_py_unittest( + name = "yaml_typed_test", + data = [ + "//common/yaml:test/yaml_io_test_input_1.yaml", + ], + deps = [ + ":yaml_py", + "//bindings/pydrake/common/test_utilities:meta_py", + ], +) + add_pybind_coverage_data() add_lint_tests_pydrake() diff --git a/bindings/pydrake/common/test/yaml_test.py b/bindings/pydrake/common/test/yaml_test.py index 0e4d0bc46dd0..d73a7056b45b 100644 --- a/bindings/pydrake/common/test/yaml_test.py +++ b/bindings/pydrake/common/test/yaml_test.py @@ -11,6 +11,7 @@ class TestYaml(unittest.TestCase): + """Tests for the untyped yaml_load / yaml_dump functions.""" def test_via_file(self): filename = os.path.join(os.environ["TEST_TMPDIR"], "foo.yaml") diff --git a/bindings/pydrake/common/test/yaml_typed_test.py b/bindings/pydrake/common/test/yaml_typed_test.py new file mode 100644 index 000000000000..5ae99b44f8ee --- /dev/null +++ b/bindings/pydrake/common/test/yaml_typed_test.py @@ -0,0 +1,463 @@ +import dataclasses as dc +import math +from math import nan +from textwrap import dedent +import typing +import unittest + +import numpy as np + +from pydrake.common import FindResourceOrThrow +from pydrake.common.test_utilities.meta import ( + ValueParameterizedTest, + run_with_multiple_values, +) +from pydrake.common.yaml import yaml_load_typed + + +# To provide test coverage for all of the special cases of YAML loading, we'll +# define some dataclasses. These classes mimic +# drake/common/yaml/test/example_structs.h +# and should be roughly kept in sync with the definitions in that file. + + +@dc.dataclass +class FloatStruct: + value: float = nan + + +@dc.dataclass +class StringStruct: + value: str = "nominal_string" + + +@dc.dataclass +class AllScalarsStruct: + some_bool: bool = False + some_float: float = nan + some_int: int = 11 + some_str: str = "nominal_string" + + +@dc.dataclass +class ListStruct: + value: typing.List[float] = dc.field( + default_factory=lambda: list((nan,))) + + +@dc.dataclass +class MapStruct: + value: typing.Dict[str, float] = dc.field( + default_factory=lambda: dict(nominal_float=nan)) + + +@dc.dataclass +class InnerStruct: + inner_value: float = nan + + +@dc.dataclass +class OptionalStruct: + value: typing.Optional[float] = nan + + +@dc.dataclass +class OptionalStructNoDefault: + value: typing.Optional[float] = None + + +@dc.dataclass +class NumpyStruct: + # TODO(jwnimmer-tri) Once we drop support for Ubuntu 20.04 "Focal", then we + # can upgrade to numpy >= 1.21 as our minimum at which point we can use the + # numpy.typing module here to constrain the shape and/or dtype. + value: np.ndarray = dc.field( + default_factory=lambda: np.array([nan])) + + +@dc.dataclass +class OuterStruct: + outer_value: float = nan + inner_struct: InnerStruct = dc.field( + default_factory=lambda: InnerStruct()) + + +@dc.dataclass +class BigMapStruct: + value: typing.Mapping[str, OuterStruct] = dc.field( + default_factory=lambda: dict( + foo=OuterStruct( + outer_value=1.0, + inner_struct=InnerStruct(inner_value=2.0)))) + + +class TestYamlTypedRead(unittest.TestCase, + metaclass=ValueParameterizedTest): + """Detailed tests for the typed yaml_load function(s). 
+
+    This test class is the Python flavor of the C++ test suite at
+    drake/common/yaml/test/yaml_read_archive_test.cc
+    and should be roughly kept in sync with the test cases in that file.
+    """
+
+    def _all_typed_read_options(
+            sweep_allow_yaml_with_no_schema=(True, False),
+            sweep_allow_schema_with_no_yaml=(True, False),
+            sweep_retain_map_defaults=(True, False)):
+        """Returns the options matrix for our value-parameterized test cases.
+        """
+        result = []
+        for i in sweep_allow_yaml_with_no_schema:
+            for j in sweep_allow_schema_with_no_yaml:
+                for k in sweep_retain_map_defaults:
+                    result.append(dict(options=dict(
+                        allow_yaml_with_no_schema=i,
+                        allow_schema_with_no_yaml=j,
+                        retain_map_defaults=k,
+                    )))
+        return result
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_float(self, *, options):
+        cases = [
+            ("0", 0.0),
+            ("1", 1.0),
+            ("-1", -1.0),
+            ("0.0", 0.0),
+            ("1.2", 1.2),
+            ("-1.2", -1.2),
+            ("3e4", 3e4),
+            ("3e-4", 3e-4),
+            ("5.6e7", 5.6e7),
+            ("5.6e-7", 5.6e-7),
+            ("-5.6e7", -5.6e7),
+            ("-5.6e-7", -5.6e-7),
+            ("3E4", 3e4),
+            ("3E-4", 3e-4),
+            ("5.6E7", 5.6e7),
+            ("5.6E-7", 5.6e-7),
+            ("-5.6E7", -5.6e7),
+            ("-5.6E-7", -5.6e-7),
+        ]
+        for value, expected in cases:
+            data = f"value: {value}"
+            x = yaml_load_typed(schema=FloatStruct, data=data, **options)
+            self.assertEqual(x.value, expected)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_float_missing(self, *, options):
+        if options["allow_schema_with_no_yaml"]:
+            x = yaml_load_typed(schema=FloatStruct, data="{}",
+                                **options)
+            self.assertTrue(math.isnan(x.value), msg=repr(x.value))
+        else:
+            with self.assertRaisesRegex(RuntimeError, ".*missing.*"):
+                yaml_load_typed(schema=FloatStruct, data="{}",
+                                **options)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_all_scalars(self, *, options):
+        data = dedent("""
+        some_bool: true
+        some_float: 101.0
+        some_int: 102
+        some_str: foo
+        """)
+        x = yaml_load_typed(schema=AllScalarsStruct, data=data, **options)
+        self.assertEqual(x.some_bool, True)
+        self.assertEqual(x.some_float, 101.0)
+        self.assertEqual(x.some_int, 102)
+        self.assertEqual(x.some_str, "foo")
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_list(self, *, options):
+        cases = [
+            ("[1.0, 2.0, 3.0]", [1.0, 2.0, 3.0]),
+        ]
+        for value, expected in cases:
+            data = f"value: {value}"
+            x = yaml_load_typed(schema=ListStruct, data=data, **options)
+            self.assertEqual(x.value, expected)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_list_missing(self, *, options):
+        if options["allow_schema_with_no_yaml"]:
+            x = yaml_load_typed(schema=ListStruct, data="{}", **options)
+            self.assertEqual(len(x.value), 1)
+            self.assertTrue(math.isnan(x.value[0]), msg=repr(x.value))
+        else:
+            with self.assertRaisesRegex(RuntimeError, ".*missing.*"):
+                yaml_load_typed(schema=ListStruct, data="{}", **options)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_map(self, *, options):
+        data = dedent("""
+        value:
+          foo: 0.0
+          bar: 1.0
+        """)
+        x = yaml_load_typed(schema=MapStruct, data=data, **options)
+        expected = dict(foo=0.0, bar=1.0)
+        if options["retain_map_defaults"]:
+            expected.update(nominal_float=nan)
+        self.assertEqual(x.value, expected)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_big_map_append(self, *, options):
+        data = dedent("""
+        value:
+          bar:
+            outer_value: 3.0
+            inner_struct:
+              inner_value: 4.0
+        """)
+        x = yaml_load_typed(schema=BigMapStruct, data=data, **options)
+        
expected = dict(bar=OuterStruct(3.0, InnerStruct(4.0))) + if options["retain_map_defaults"]: + expected.update(foo=OuterStruct(1.0, InnerStruct(2.0))) + self.assertEqual(x.value, expected) + + @run_with_multiple_values(_all_typed_read_options( + # When False, the parser raises an exception not worth testing for. + sweep_allow_schema_with_no_yaml=[True])) + def test_read_big_map_merge_new_outer_value(self, *, options): + data = dedent(""" + value: + foo: + outer_value: 3.0 + """) + x = yaml_load_typed(schema=BigMapStruct, data=data, **options) + expected = dict(foo=OuterStruct(3.0)) + if options["retain_map_defaults"]: + expected["foo"].inner_struct.inner_value = 2.0 + self.assertEqual(x.value, expected) + + @run_with_multiple_values(_all_typed_read_options( + # When False, the parser raises an exception not worth testing for. + sweep_allow_schema_with_no_yaml=[True])) + def test_read_big_map_merge_new_inner_value(self, *, options): + data = dedent(""" + value: + foo: + inner_struct: + inner_value: 4.0 + """) + x = yaml_load_typed(schema=BigMapStruct, data=data, **options) + expected = dict(foo=OuterStruct(inner_struct=InnerStruct(4.0))) + if options["retain_map_defaults"]: + expected["foo"].outer_value = 1.0 + self.assertEqual(x.value, expected) + + @run_with_multiple_values(_all_typed_read_options( + # When False, the parser raises an exception not worth testing for. + sweep_allow_schema_with_no_yaml=[True])) + def test_read_big_map_merge_empty(self, *, options): + data = dedent(""" + value: + foo: {} + """) + x = yaml_load_typed(schema=BigMapStruct, data=data, **options) + expected = dict(foo=OuterStruct()) + if options["retain_map_defaults"]: + expected["foo"].outer_value = 1.0 + expected["foo"].inner_struct.inner_value = 2.0 + self.assertEqual(x.value, expected) + + @run_with_multiple_values(_all_typed_read_options()) + def test_read_map_missing(self, *, options): + if options["allow_schema_with_no_yaml"]: + x = yaml_load_typed(schema=MapStruct, data="{}", **options) + self.assertEqual(x.value, dict(nominal_float=nan)) + else: + with self.assertRaisesRegex(RuntimeError, ".*missing.*"): + yaml_load_typed(schema=MapStruct, data="{}", **options) + + # TODO(jwnimmer-tri) Add test cases similar to StdMapWithMergeKeys + # and StdMapWithBadMergeKey from the C++ YAML test suite. + + # TODO(jwnimmer-tri) Add test cases similar to StdMapDirectly and + # StdMapDirectlyWithDefaults from the C++ YAML test suite. + + @run_with_multiple_values(_all_typed_read_options()) + def test_read_optional(self, *, options): + # The test case numbers here (1..12) reference the specification as + # documented in the C++ unit test yaml_read_archive_test.cc. 
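+        # In brief: a value that is present parses normally; an explicit null
+        # parses as None; and an absent key keeps the schema default when
+        # allow_schema_with_no_yaml is set, but otherwise is forced to None.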
+        for schema, data, expected in (
+                (OptionalStructNoDefault, "value: 1.0", 1.0),  # Case 1, 2
+                (OptionalStruct, "value: 1.0", 1.0),           # Case 3, 4
+                (OptionalStructNoDefault, "value:", None),     # Case 5, 6
+                (OptionalStruct, "value:", None),              # Case 7, 8
+                (OptionalStructNoDefault, "{}", None),         # Case 9, 10
+                (OptionalStruct, "{}", (
+                    nan if options["allow_schema_with_no_yaml"]  # Case 12
+                    else None)),                                 # Case 11
+                ):
+            with self.subTest(data=data, schema=schema):
+                actual = yaml_load_typed(schema=schema, data=data, **options)
+                self.assertEqual(actual, schema(expected))
+                if options["allow_yaml_with_no_schema"]:
+                    if "value:" in data:
+                        amended_data = "foo: bar\n" + data
+                    else:
+                        amended_data = "foo: bar"
+                    actual = yaml_load_typed(
+                        schema=schema, data=amended_data, **options)
+                    self.assertEqual(actual, schema(expected))
+
+    # TODO(jwnimmer-tri) Add test cases similar to Variant and VariantMissing
+    # from the C++ YAML test suite.
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_np_vector(self, *, options):
+        data = "value: [1.0, 2.0, 3.0]"
+        expected = [1.0, 2.0, 3.0]
+        x = yaml_load_typed(schema=NumpyStruct, data=data, **options)
+        np.testing.assert_equal(x.value, np.array(expected), verbose=True)
+
+        data = "value: [1.0]"
+        expected = [1.0]
+        x = yaml_load_typed(schema=NumpyStruct, data=data, **options)
+        np.testing.assert_equal(x.value, np.array(expected), verbose=True)
+
+        data = "value: []"
+        expected = []
+        x = yaml_load_typed(schema=NumpyStruct, data=data, **options)
+        np.testing.assert_equal(x.value, np.array(expected), verbose=True)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_np_matrix(self, *, options):
+        data = dedent("""
+        value:
+        - [0.0, 1.0, 2.0, 3.0]
+        - [4.0, 5.0, 6.0, 7.0]
+        - [8.0, 9.0, 10.0, 11.0]
+        """)
+        expected = [
+            [0.0, 1.0, 2.0, 3.0],
+            [4.0, 5.0, 6.0, 7.0],
+            [8.0, 9.0, 10.0, 11.0],
+        ]
+        x = yaml_load_typed(schema=NumpyStruct, data=data, **options)
+        np.testing.assert_equal(x.value, np.array(expected), verbose=True)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_np_missing(self, *, options):
+        schema = NumpyStruct
+        data = "{}"
+        expected = [nan]
+        if options["allow_schema_with_no_yaml"]:
+            x = yaml_load_typed(schema=schema, data=data, **options)
+            np.testing.assert_equal(x.value, np.array(expected), verbose=True)
+        else:
+            with self.assertRaisesRegex(RuntimeError, ".*missing.*"):
+                yaml_load_typed(schema=schema, data=data, **options)
+
+    @run_with_multiple_values(_all_typed_read_options())
+    def test_read_nested(self, *, options):
+        data = dedent("""
+        outer_value: 1.0
+        inner_struct:
+          inner_value: 2.0
+        """)
+        x = yaml_load_typed(schema=OuterStruct, data=data, **options)
+        self.assertEqual(x, OuterStruct(1.0, InnerStruct(2.0)))
+
+    # TODO(jwnimmer-tri) Add a test case similar to NestedWithMergeKeys from
+    # the C++ YAML test suite.
+
+    # TODO(jwnimmer-tri) Add a test case similar to NestedWithBadMergeKey from
+    # the C++ YAML test suite.
+
+    # TODO(jwnimmer-tri) Add test cases similar to these from the C++ YAML
+    # test suite:
+    # - VisitScalarFoundNothing
+    # - VisitScalarFoundArray
+    # - VisitScalarFoundStruct
+    # - VisitArrayFoundNothing
+    # - VisitArrayFoundScalar
+    # - VisitArrayFoundStruct
+    # - VisitVectorFoundNothing
+    # - VisitVectorFoundScalar
+    # - VisitVectorFoundStruct
+    # - VisitOptionalScalarFoundSequence
+    # - VisitVariantFoundNoTag
+    # - VisitVariantFoundUnknownTag
+    # - VisitEigenFoundNothing
+    # - VisitEigenFoundScalar
+    # - VisitEigenMatrixFoundOneDimensional
+    # - VisitEigenMatrixFoundNonSquare
+    # - VisitStructFoundNothing
+    # - VisitStructFoundScalar
+    # - VisitStructFoundArray
+
+
+class TestYamlTypedReadAcceptance(unittest.TestCase):
+    """Acceptance tests for the typed yaml_load function(s).
+
+    This test class is the Python flavor of the C++ test suite at
+    drake/common/yaml/test/yaml_io_test.cc
+    and should be roughly kept in sync with the test cases in that file.
+    """
+
+    def test_load_string(self):
+        data = dedent("""
+        value:
+          some_value
+        """)
+        result = yaml_load_typed(schema=StringStruct, data=data)
+        self.assertEqual(result.value, "some_value")
+
+    def test_load_string_child_name(self):
+        data = dedent("""
+        some_child_name:
+          value:
+            some_value
+        """)
+        result = yaml_load_typed(schema=StringStruct, data=data,
+                                 child_name="some_child_name")
+        self.assertEqual(result.value, "some_value")
+
+        # When the requested child_name does not exist, that's an error.
+        with self.assertRaisesRegex(KeyError, "wrong_child_name"):
+            yaml_load_typed(schema=StringStruct, data=data,
+                            child_name="wrong_child_name")
+
+    def test_load_string_defaults(self):
+        data = dedent("""
+        value:
+          some_key: 1.0
+        """)
+        defaults = MapStruct()
+
+        # Merge the default map value(s).
+        result = yaml_load_typed(
+            schema=MapStruct, data=data, defaults=defaults)
+        self.assertDictEqual(result.value, dict(
+            nominal_float=nan,
+            some_key=1.0))
+
+        # Replace the default map value(s).
+        result = yaml_load_typed(
+            schema=MapStruct, data=data, defaults=defaults,
+            retain_map_defaults=False)
+        self.assertDictEqual(result.value, dict(some_key=1.0))
+
+    def test_load_string_options(self):
+        data = dedent("""
+        value: some_value
+        extra_junk: will_be_ignored
+        """)
+        result = yaml_load_typed(schema=StringStruct, data=data,
+                                 allow_yaml_with_no_schema=True)
+        self.assertEqual(result.value, "some_value")
+
+        # Cross-check that the option actually was important.
+        with self.assertRaisesRegex(RuntimeError, ".*extra_junk.*"):
+            yaml_load_typed(schema=StringStruct, data=data)
+
+    def test_load_file(self):
+        filename = FindResourceOrThrow(
+            "drake/common/yaml/test/yaml_io_test_input_1.yaml")
+        result = yaml_load_typed(schema=StringStruct, filename=filename)
+        self.assertEqual(result.value, "some_value_1")
diff --git a/bindings/pydrake/common/yaml.py b/bindings/pydrake/common/yaml.py
index ab4eeccee94f..fe543e3a68fe 100644
--- a/bindings/pydrake/common/yaml.py
+++ b/bindings/pydrake/common/yaml.py
@@ -1,5 +1,10 @@
+import collections.abc
 import copy
+import dataclasses
+import functools
+import typing
+
+import numpy as np
 import yaml
 
 
@@ -31,14 +36,18 @@ def _handle_multi_variant(loader, tag, node):
 
 def yaml_load_data(data, *, private=False):
     """Loads and returns the given `data` str as a yaml object, while also
-    accounting for variant-like type tags.  The known variant yaml tags
-    are reported as an extra "_tag" field in the returned dictionary.
+    accounting for variant-like type tags.  Any tags are reported as an
+    extra "_tag" field in the returned dictionary.
 
     (Alternatively, `data` may be a file-like stream instead of a str.)
 
     By default, removes any root-level keys that begin with an underscore,
     so that yaml anchors and templates are invisibly easy to use.  Callers
     that wish to receive the private data may pass `private=True`.
+
+    This function returns the raw, untyped data (dict, list, str, float, etc.)
+    without any schema checking or default values.  To load with respect to
+    a schema with defaults, see ``yaml_load_typed()``.
     """
     result = yaml.load(data, Loader=_SchemaLoader)
     if not private:
@@ -68,6 +77,10 @@ def yaml_load(*, data=None, filename=None, private=False):
 
     This is sugar for calling either yaml_load_data or yaml_load_file;
     refer to those functions for additional details.
+
+    This function returns the raw, untyped data (dict, list, str, float, etc.)
+    without any schema checking or default values.  To load with respect to
+    a schema with defaults, see ``yaml_load_typed()``.
     """
     if sum(bool(x) for x in [data, filename]) != 1:
         raise RuntimeError("Must specify exactly one of data= and filename=")
@@ -114,3 +127,274 @@ def yaml_dump(data, *, filename=None):
     else:
         return yaml.dump(data, Dumper=_SchemaDumper,
                          default_flow_style=_FLOW_STYLE)
+
+
+_LoadYamlOptions = collections.namedtuple("LoadYamlOptions", [
+    "allow_yaml_with_no_schema",
+    "allow_schema_with_no_yaml",
+    "retain_map_defaults",
+])
+
+
+def _enumerate_field_types(schema):
+    """Returns a Mapping[str, type] of the schema-based field names and types
+    of the given type `schema`.
+    """
+    assert isinstance(schema, type)
+
+    # Dataclasses offer a public API for introspection.
+    if dataclasses.is_dataclass(schema):
+        return dict([
+            (field.name, field.type)
+            for field in dataclasses.fields(schema)])
+
+    raise NotImplementedError(
+        f"Schema objects of type {schema} are not yet supported")
+
+
+def _get_nested_optional_type(schema):
+    """If the given schema (i.e., type) is equivalent to an Optional[Foo], then
+    returns Foo.  Otherwise, returns None.
+    """
+    generic_base = typing.get_origin(schema)
+    if generic_base == typing.Union:
+        generic_args = typing.get_args(schema)
+        NoneType = type(None)
+        if len(generic_args) == 2 and generic_args[-1] == NoneType:
+            (nested_type, _) = generic_args
+            return nested_type
+    return None
+
+
+def _merge_yaml_dict_item_into_target(*, options, name, yaml_value,
+                                      target, value_schema):
+    """Parses the given `yaml_value` into an object of type `value_schema`,
+    writing the result to the field named `name` of the given `target` object.
+    """
+    # The target can be either a dictionary or a dataclass.
+    if isinstance(target, collections.abc.Mapping):
+        old_value = target[name]
+        setter = functools.partial(target.__setitem__, name)
+    else:
+        old_value = getattr(target, name)
+        setter = functools.partial(setattr, target, name)
+
+    # Handle all of the plain YAML scalars:
+    #  https://yaml.org/spec/1.2.2/#scalars
+    #  https://yaml.org/spec/1.2.2/#json-schema
+    if value_schema in (bool, int, float, str):
+        new_value = value_schema(yaml_value)
+        setter(new_value)
+        return
+
+    # Handle nullable types (std::optional or typing.Optional[T]).
+    nested_optional_type = _get_nested_optional_type(value_schema)
+    if nested_optional_type is not None:
+        # If the yaml was null, the Python field will be None.
+        if yaml_value is None:
+            setter(None)
+            return
+        # Create a non-null default value, if necessary.
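+        # (E.g., for typing.Optional[float] whose current value is None, this
+        # constructs float() == 0.0, so that the recursive merge below has a
+        # non-null value to parse into.)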
+ if old_value is None: + setter(nested_optional_type()) + # Now we can parse Optional[Foo] like a plain Foo. + _merge_yaml_dict_item_into_target( + options=options, name=name, yaml_value=yaml_value, target=target, + value_schema=nested_optional_type) + return + + # Handle NumPy types. + if value_schema == np.ndarray: + new_value = np.array(yaml_value, dtype=float) + setter(new_value) + return + + # Check if the field is generic like list[str]; if yes, the generic_base + # will be, e.g., `list` and generic_args will be, e.g., `[str]`. + generic_base = typing.get_origin(value_schema) + generic_args = typing.get_args(value_schema) + + # Handle YAML sequences: + # https://yaml.org/spec/1.2.2/#sequence + # + # In Drake's YamlLoad convention, merging a sequence denotes *overwriting* + # what was there. + if generic_base in (list, typing.List): + (value_type,) = generic_args + new_value = [] + for sub_yaml_value in yaml_value: + sub_target = {"_": value_type()} + _merge_yaml_dict_item_into_target( + options=options, name="_", yaml_value=sub_yaml_value, + target=sub_target, value_schema=value_type) + new_value.append(sub_target["_"]) + setter(new_value) + return + + # Handle YAML maps: + # https://yaml.org/spec/1.2.2/#mapping + # + # In Drake's YamlLoad convention, merging a mapping denotes *updating* + # what was there iff retain_map_defaults was set. + if generic_base in (dict, collections.abc.Mapping): + (key_type, value_type) = generic_args + assert key_type == str + if options.retain_map_defaults: + new_value = copy.deepcopy(old_value) + else: + new_value = dict() + for sub_key, sub_yaml_value in yaml_value.items(): + if sub_key not in new_value: + new_value[sub_key] = value_type() + _merge_yaml_dict_item_into_target( + options=options, name=sub_key, yaml_value=sub_yaml_value, + target=new_value, value_schema=value_type) + setter(new_value) + return + + # Handle schema sum types (std::variant<...> or typing.Union[...]). + if generic_base is typing.Union: + # TODO(jwnimmer-tri) Implement me. + raise NotImplementedError("Union[] types are not yet supported") + + # By this point, we've handled all known cases of generic types. + if generic_base is not None: + raise NotImplementedError( + f"The generic type {generic_base} of {value_schema} is " + "not yet supported") + + # If the value_schema is neither primitive nor generic, then we'll assume + # it's a directly-nested subclass. + new_value = copy.deepcopy(old_value) + _merge_yaml_dict_into_target( + options=options, yaml_dict=yaml_value, + target=new_value, target_schema=value_schema) + setter(new_value) + + +def _merge_yaml_dict_into_target(*, options, yaml_dict, + target, target_schema): + """Merges the given yaml_dict into the given target (of given type). + The target must be an instance of some dataclass or pybind11 class. + The yaml_dict must be typed like the result of calling yaml_load (i.e., + raw strings, dictionaries, lists, etc.). + """ + assert isinstance(yaml_dict, collections.abc.Mapping), yaml_dict + if "_tag" in yaml_dict: + # TODO(jwnimmer-tri) Implement me. 
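+        # (A "_tag" key is how yaml_load_data reports a YAML type tag, e.g.,
+        # a variant selector; see _handle_multi_variant above.)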
+        raise NotImplementedError("Union[] type tags are not yet supported")
+    static_field_map = _enumerate_field_types(target_schema)
+    schema_names = list(static_field_map.keys())
+    schema_optionals = set([
+        name for name, sub_schema in static_field_map.items()
+        if _get_nested_optional_type(sub_schema) is not None
+    ])
+    yaml_names = list(yaml_dict.keys())
+    extra_yaml_names = [
+        name for name in yaml_names
+        if name not in schema_names
+    ]
+    missing_yaml_names = [
+        name for name, sub_schema in static_field_map.items()
+        if name not in yaml_names and name not in schema_optionals
+    ]
+    if extra_yaml_names and not options.allow_yaml_with_no_schema:
+        raise RuntimeError(
+            f"The fields {extra_yaml_names} were unknown to the schema")
+    if missing_yaml_names and not options.allow_schema_with_no_yaml:
+        raise RuntimeError(
+            f"The fields {missing_yaml_names} were missing in the yaml data")
+    for name, sub_schema in static_field_map.items():
+        if name in yaml_dict:
+            sub_value = yaml_dict[name]
+        elif name in schema_optionals:
+            # For Optional fields that are missing from the yaml data, we must
+            # match the C++ heuristic: when "allow no yaml" is set, we'll
+            # leave the existing value unchanged.  Otherwise, we need to
+            # affirmatively set the target to the None value.
+            if options.allow_schema_with_no_yaml:
+                continue
+            sub_value = None
+        else:
+            # Errors for non-Optional missing yaml data have already been
+            # implemented above (see "missing_yaml_names"), so we should just
+            # skip over those fields here.  They will remain unchanged.
+            continue
+        _merge_yaml_dict_item_into_target(
+            options=options, name=name, yaml_value=sub_value,
+            target=target, value_schema=sub_schema)
+
+
+def yaml_load_typed(*, schema,
+                    data=None,
+                    filename=None,
+                    child_name=None,
+                    defaults=None,
+                    allow_yaml_with_no_schema=False,
+                    allow_schema_with_no_yaml=True,
+                    retain_map_defaults=True):
+    """Loads either a ``data`` str or a ``filename`` against the given
+    ``schema`` type and returns an instance of that type.
+
+    This mimics the C++ function ``drake::common::yaml::LoadYamlFile``.
+
+    Args:
+        schema: The type to load.  Must be a ``dataclass``.  (Adding support
+            for more type classes is future work.)
+        data: The string of YAML data to be loaded.  Exactly one of either
+            ``data`` or ``filename`` must be provided.
+        filename: The filename of YAML data to be loaded.  Exactly one of
+            either ``data`` or ``filename`` must be provided.
+        child_name: If provided, loads data from the given-named child of the
+            document's root instead of the root itself.
+        defaults: If provided, then the object being read into will be
+            initialized using this value instead of the schema's default
+            constructor.
+        allow_yaml_with_no_schema: Allows yaml Maps to have extra key-value
+            pairs that are not specified by the schema being parsed into.  In
+            other words, the schema argument provides only an incomplete
+            schema for the YAML data.  This allows for parsing only a subset
+            of the YAML data.
+        allow_schema_with_no_yaml: Allows the schema to provide more key-value
+            pairs than are present in the YAML data.  In other words, objects
+            can have default values that are left intact unless the YAML data
+            provides a value.
+        retain_map_defaults: If set to true, when parsing a Mapping the loader
+            will merge the YAML data into the destination, instead of
+            replacing the dict contents entirely.  In other words, a Mapping
+            field in a schema can have default values that are left intact
+            unless the YAML data provides a value *for that specific key*.
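+
+    For example, here is a minimal sketch of loading a caller-defined
+    dataclass (``MyConfig`` is hypothetical, standing in for whatever schema
+    the caller declares)::
+
+        @dataclasses.dataclass
+        class MyConfig:
+            name: str = "default"
+            gains: typing.List[float] = dataclasses.field(
+                default_factory=lambda: [1.0])
+
+        config = yaml_load_typed(schema=MyConfig,
+                                 data="{name: foo, gains: [2.0, 3.0]}")
+        assert config.name == "foo"
+        assert config.gains == [2.0, 3.0]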
+ """ + # Choose the allow/retain setting in case none were provided. + options = _LoadYamlOptions( + allow_yaml_with_no_schema=allow_yaml_with_no_schema, + allow_schema_with_no_yaml=allow_schema_with_no_yaml, + retain_map_defaults=retain_map_defaults) + + # Create the result object. + if defaults is not None: + result = copy.deepcopy(defaults) + else: + result = schema() + + # Parse the YAML document. + document = yaml_load(data=data, filename=filename) + if child_name is not None: + root_node = document[child_name] + else: + root_node = document + + # Merge the document into the result. + _merge_yaml_dict_into_target( + options=options, yaml_dict=root_node, + target=result, target_schema=schema) + return result + + +__all__ = [ + "yaml_dump", + "yaml_load", + "yaml_load_data", + "yaml_load_file", + "yaml_load_typed", +] diff --git a/common/yaml/test/example_structs.h b/common/yaml/test/example_structs.h index 8746cf94c920..286ae6b7f1b7 100644 --- a/common/yaml/test/example_structs.h +++ b/common/yaml/test/example_structs.h @@ -17,6 +17,10 @@ namespace drake { namespace yaml { namespace test { +// These data structures are the C++ flavor of the Python test classes at +// drake/bindings/pydrake/common/test/yaml_typed_test.py +// and should be roughly kept in sync with the code in that file. + // A value used in the test data below to include a default (placeholder) value // when initializing struct data members. constexpr double kNominalDouble = 1.2345; diff --git a/common/yaml/test/yaml_read_archive_test.cc b/common/yaml/test/yaml_read_archive_test.cc index 4fed92766d94..04371c08765a 100644 --- a/common/yaml/test/yaml_read_archive_test.cc +++ b/common/yaml/test/yaml_read_archive_test.cc @@ -11,6 +11,10 @@ #include "drake/common/test_utilities/expect_throws_message.h" #include "drake/common/yaml/test/example_structs.h" +// This test suite is the C++ flavor of the Python test suite at +// drake/bindings/pydrake/common/test/yaml_typed_test.py +// and should be roughly kept in sync with the test cases in that file. + // TODO(jwnimmer-tri) All of these regexps would be better off using the // std::regex::basic grammar, where () and {} are not special characters.