Skip to content

Commit

Permalink
add missing bindings for State<T> abstract values. (needed for implem…
Browse files Browse the repository at this point in the history
…enting unrestricted updates from python) (#11416)
  • Loading branch information
RussTedrake authored May 8, 2019
1 parent b6190fb commit cabd2ab
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 10 deletions.
46 changes: 40 additions & 6 deletions bindings/pydrake/systems/framework_py_semantics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,11 @@ void DefineFrameworkPySemantics(py::module m) {
.def(py::init<>(), doc.AbstractValues.ctor.doc_0args)
.def(py::init<AbstractValuePtrList>(), doc.AbstractValues.ctor.doc_1args)
.def("size", &AbstractValues::size, doc.AbstractValues.size.doc)
.def("get_value", &AbstractValues::get_value, py_reference_internal,
doc.AbstractValues.get_value.doc)
.def("get_value", &AbstractValues::get_value, py::arg("index"),
py_reference_internal, doc.AbstractValues.get_value.doc)
.def("get_mutable_value", &AbstractValues::get_mutable_value,
py_reference_internal, doc.AbstractValues.get_mutable_value.doc)
py::arg("index"), py_reference_internal,
doc.AbstractValues.get_mutable_value.doc)
.def("CopyFrom",
[](AbstractValues* self, const AbstractValues& other) {
WarnDeprecated(
Expand Down Expand Up @@ -309,7 +310,8 @@ void DefineFrameworkPySemantics(py::module m) {
[](const Context<T>* self, int index) -> auto& {
return self->get_abstract_state().get_value(index);
},
py_reference_internal, doc.Context.get_abstract_state.doc_1args)
py::arg("index"), py_reference_internal,
doc.Context.get_abstract_state.doc_1args)
.def("get_mutable_abstract_state",
[](Context<T>* self) -> AbstractValues& {
return self->get_mutable_abstract_state();
Expand All @@ -321,7 +323,7 @@ void DefineFrameworkPySemantics(py::module m) {
return self->get_mutable_abstract_state().get_mutable_value(
index);
},
py_reference_internal,
py::arg("index"), py_reference_internal,
doc.Context.get_mutable_abstract_state.doc_1args)
.def("SetAbstractState",
[](py::object self, int index, py::object value) {
Expand Down Expand Up @@ -634,7 +636,39 @@ void DefineFrameworkPySemantics(py::module m) {
.def("get_mutable_discrete_state",
overload_cast_explicit<DiscreteValues<T>&>(
&State<T>::get_mutable_discrete_state),
py_reference_internal, doc.State.get_mutable_discrete_state.doc);
py_reference_internal, doc.State.get_mutable_discrete_state.doc)
.def("get_discrete_state",
overload_cast_explicit<const BasicVector<T>&, int>(
&State<T>::get_discrete_state),
py::arg("index"), py_reference_internal,
doc.State.get_discrete_state.doc)
.def("get_mutable_discrete_state",
overload_cast_explicit<BasicVector<T>&, int>(
&State<T>::get_mutable_discrete_state),
py::arg("index"), py_reference_internal,
doc.State.get_mutable_discrete_state.doc)
.def("get_abstract_state",
static_cast<const AbstractValues& (State<T>::*)() const>(
&State<T>::get_abstract_state),
py_reference_internal, doc.State.get_abstract_state.doc)
.def("get_mutable_abstract_state",
[](State<T>* self) -> AbstractValues& {
return self->get_mutable_abstract_state();
},
py_reference_internal, doc.State.get_mutable_abstract_state.doc)
.def("get_abstract_state",
[](const State<T>* self, int index) -> auto& {
return self->get_abstract_state().get_value(index);
},
py::arg("index"), py_reference_internal,
doc.State.get_abstract_state.doc)
.def("get_mutable_abstract_state",
[](State<T>* self, int index) -> AbstractValue& {
return self->get_mutable_abstract_state().get_mutable_value(
index);
},
py::arg("index"), py_reference_internal,
doc.State.get_mutable_abstract_state.doc);

// - Constituents.
DefineTemplateClassWithDefault<ContinuousState<T>>(
Expand Down
24 changes: 20 additions & 4 deletions bindings/pydrake/systems/test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,28 @@ def test_context_api(self):
np.testing.assert_equal(
context.get_discrete_state_vector().CopyToVector(), 3 * x)

def check_abstract_value_zero(context, expected_value):
# Check through Context, State, and AbstractValues APIs.
self.assertEqual(context.get_abstract_state(index=0).get_value(),
expected_value)
self.assertEqual(context.get_abstract_state().get_value(
index=0).get_value(), expected_value)
self.assertEqual(context.get_state().get_abstract_state(
index=0).get_value(), expected_value)
self.assertEqual(context.get_state().get_abstract_state()
.get_value(index=0).get_value(), expected_value)

context.SetAbstractState(index=0, value=True)
value = context.get_abstract_state(0)
self.assertTrue(value.get_value())
check_abstract_value_zero(context, True)
context.SetAbstractState(index=0, value=False)
value = context.get_abstract_state(0)
self.assertFalse(value.get_value())
check_abstract_value_zero(context, False)
value = context.get_mutable_state().get_mutable_abstract_state(index=0)
value.set_value(True)
check_abstract_value_zero(context, True)
value = context.get_mutable_state().get_mutable_abstract_state()\
.get_mutable_value(index=0)
value.set_value(False)
check_abstract_value_zero(context, False)

def test_event_api(self):
# TriggerType - existence check.
Expand Down

0 comments on commit cabd2ab

Please sign in to comment.