From cabd2abd187ceae593510b93fade1ba8bb8ea654 Mon Sep 17 00:00:00 2001 From: Russ Tedrake Date: Wed, 8 May 2019 14:18:33 -0400 Subject: [PATCH] add missing bindings for State abstract values. (needed for implementing unrestricted updates from python) (#11416) --- .../pydrake/systems/framework_py_semantics.cc | 46 ++++++++++++++++--- bindings/pydrake/systems/test/general_test.py | 24 ++++++++-- 2 files changed, 60 insertions(+), 10 deletions(-) diff --git a/bindings/pydrake/systems/framework_py_semantics.cc b/bindings/pydrake/systems/framework_py_semantics.cc index 239b6e167f46..865003af0657 100644 --- a/bindings/pydrake/systems/framework_py_semantics.cc +++ b/bindings/pydrake/systems/framework_py_semantics.cc @@ -100,10 +100,11 @@ void DefineFrameworkPySemantics(py::module m) { .def(py::init<>(), doc.AbstractValues.ctor.doc_0args) .def(py::init(), doc.AbstractValues.ctor.doc_1args) .def("size", &AbstractValues::size, doc.AbstractValues.size.doc) - .def("get_value", &AbstractValues::get_value, py_reference_internal, - doc.AbstractValues.get_value.doc) + .def("get_value", &AbstractValues::get_value, py::arg("index"), + py_reference_internal, doc.AbstractValues.get_value.doc) .def("get_mutable_value", &AbstractValues::get_mutable_value, - py_reference_internal, doc.AbstractValues.get_mutable_value.doc) + py::arg("index"), py_reference_internal, + doc.AbstractValues.get_mutable_value.doc) .def("CopyFrom", [](AbstractValues* self, const AbstractValues& other) { WarnDeprecated( @@ -309,7 +310,8 @@ void DefineFrameworkPySemantics(py::module m) { [](const Context* self, int index) -> auto& { return self->get_abstract_state().get_value(index); }, - py_reference_internal, doc.Context.get_abstract_state.doc_1args) + py::arg("index"), py_reference_internal, + doc.Context.get_abstract_state.doc_1args) .def("get_mutable_abstract_state", [](Context* self) -> AbstractValues& { return self->get_mutable_abstract_state(); @@ -321,7 +323,7 @@ void DefineFrameworkPySemantics(py::module m) { return self->get_mutable_abstract_state().get_mutable_value( index); }, - py_reference_internal, + py::arg("index"), py_reference_internal, doc.Context.get_mutable_abstract_state.doc_1args) .def("SetAbstractState", [](py::object self, int index, py::object value) { @@ -634,7 +636,39 @@ void DefineFrameworkPySemantics(py::module m) { .def("get_mutable_discrete_state", overload_cast_explicit&>( &State::get_mutable_discrete_state), - py_reference_internal, doc.State.get_mutable_discrete_state.doc); + py_reference_internal, doc.State.get_mutable_discrete_state.doc) + .def("get_discrete_state", + overload_cast_explicit&, int>( + &State::get_discrete_state), + py::arg("index"), py_reference_internal, + doc.State.get_discrete_state.doc) + .def("get_mutable_discrete_state", + overload_cast_explicit&, int>( + &State::get_mutable_discrete_state), + py::arg("index"), py_reference_internal, + doc.State.get_mutable_discrete_state.doc) + .def("get_abstract_state", + static_cast::*)() const>( + &State::get_abstract_state), + py_reference_internal, doc.State.get_abstract_state.doc) + .def("get_mutable_abstract_state", + [](State* self) -> AbstractValues& { + return self->get_mutable_abstract_state(); + }, + py_reference_internal, doc.State.get_mutable_abstract_state.doc) + .def("get_abstract_state", + [](const State* self, int index) -> auto& { + return self->get_abstract_state().get_value(index); + }, + py::arg("index"), py_reference_internal, + doc.State.get_abstract_state.doc) + .def("get_mutable_abstract_state", + [](State* self, int index) -> AbstractValue& { + return self->get_mutable_abstract_state().get_mutable_value( + index); + }, + py::arg("index"), py_reference_internal, + doc.State.get_mutable_abstract_state.doc); // - Constituents. DefineTemplateClassWithDefault>( diff --git a/bindings/pydrake/systems/test/general_test.py b/bindings/pydrake/systems/test/general_test.py index 7d0335dae539..53211fc1f4a3 100644 --- a/bindings/pydrake/systems/test/general_test.py +++ b/bindings/pydrake/systems/test/general_test.py @@ -164,12 +164,28 @@ def test_context_api(self): np.testing.assert_equal( context.get_discrete_state_vector().CopyToVector(), 3 * x) + def check_abstract_value_zero(context, expected_value): + # Check through Context, State, and AbstractValues APIs. + self.assertEqual(context.get_abstract_state(index=0).get_value(), + expected_value) + self.assertEqual(context.get_abstract_state().get_value( + index=0).get_value(), expected_value) + self.assertEqual(context.get_state().get_abstract_state( + index=0).get_value(), expected_value) + self.assertEqual(context.get_state().get_abstract_state() + .get_value(index=0).get_value(), expected_value) + context.SetAbstractState(index=0, value=True) - value = context.get_abstract_state(0) - self.assertTrue(value.get_value()) + check_abstract_value_zero(context, True) context.SetAbstractState(index=0, value=False) - value = context.get_abstract_state(0) - self.assertFalse(value.get_value()) + check_abstract_value_zero(context, False) + value = context.get_mutable_state().get_mutable_abstract_state(index=0) + value.set_value(True) + check_abstract_value_zero(context, True) + value = context.get_mutable_state().get_mutable_abstract_state()\ + .get_mutable_value(index=0) + value.set_value(False) + check_abstract_value_zero(context, False) def test_event_api(self): # TriggerType - existence check.