Skip to content

Commit 596da88

Browse files
pydrake systems: Expose initial round of different scalar types
1 parent daedecb commit 596da88

12 files changed

+1026
-754
lines changed

bindings/pydrake/systems/BUILD.bazel

+11-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ drake_py_library(
4444
drake_cc_library(
4545
name = "systems_pybind",
4646
hdrs = ["systems_pybind.h"],
47-
deps = ["//bindings/pydrake/util:cpp_template_pybind"],
47+
deps = [
48+
"//bindings/pydrake:autodiff_types_pybind",
49+
"//bindings/pydrake:symbolic_types_pybind",
50+
"//bindings/pydrake/util:cpp_template_pybind",
51+
],
4852
)
4953

5054
drake_pybind_library(
@@ -68,13 +72,16 @@ drake_pybind_library(
6872
package_info = PACKAGE_INFO,
6973
py_deps = [
7074
":module_py",
75+
"//bindings/pydrake:autodiffutils_py",
76+
"//bindings/pydrake:symbolic_py",
7177
"//bindings/pydrake/util:cpp_template_py",
7278
],
7379
)
7480

7581
drake_pybind_library(
7682
name = "primitives_py",
7783
cc_deps = [
84+
":systems_pybind",
7885
"//bindings/pydrake/util:drake_optional_pybind",
7986
],
8087
cc_so_name = "primitives",
@@ -88,6 +95,9 @@ drake_pybind_library(
8895

8996
drake_pybind_library(
9097
name = "analysis_py",
98+
cc_deps = [
99+
":systems_pybind",
100+
],
91101
cc_so_name = "analysis",
92102
cc_srcs = ["analysis_py.cc"],
93103
package_info = PACKAGE_INFO,

bindings/pydrake/systems/analysis_py.cc

+46-38
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "pybind11/pybind11.h"
22

33
#include "drake/bindings/pydrake/pydrake_pybind.h"
4+
#include "drake/bindings/pydrake/systems/systems_pybind.h"
45
#include "drake/systems/analysis/integrator_base.h"
56
#include "drake/systems/analysis/simulator.h"
67

@@ -15,44 +16,51 @@ PYBIND11_MODULE(analysis, m) {
1516

1617
m.doc() = "Bindings for the analysis portion of the Systems framework.";
1718

18-
using T = double;
19-
20-
py::class_<IntegratorBase<T>>(m, "IntegratorBase")
21-
.def("set_fixed_step_mode", &IntegratorBase<T>::set_fixed_step_mode)
22-
.def("get_fixed_step_mode", &IntegratorBase<T>::get_fixed_step_mode)
23-
.def("set_target_accuracy", &IntegratorBase<T>::set_target_accuracy)
24-
.def("get_target_accuracy", &IntegratorBase<T>::get_target_accuracy)
25-
.def("set_maximum_step_size", &IntegratorBase<T>::set_maximum_step_size)
26-
.def("get_maximum_step_size", &IntegratorBase<T>::get_maximum_step_size)
27-
.def("set_requested_minimum_step_size",
28-
&IntegratorBase<T>::set_requested_minimum_step_size)
29-
.def("get_requested_minimum_step_size",
30-
&IntegratorBase<T>::get_requested_minimum_step_size)
31-
.def("set_throw_on_minimum_step_size_violation",
32-
&IntegratorBase<T>::set_throw_on_minimum_step_size_violation)
33-
.def("get_throw_on_minimum_step_size_violation",
34-
&IntegratorBase<T>::get_throw_on_minimum_step_size_violation);
35-
36-
py::class_<Simulator<T>>(m, "Simulator")
37-
.def(py::init<const System<T>&>(),
38-
// Keep alive, reference: `self` keeps `System` alive.
39-
py::keep_alive<1, 2>())
40-
.def(py::init<const System<T>&, unique_ptr<Context<T>>>(),
41-
// Keep alive, reference: `self` keeps `System` alive.
42-
py::keep_alive<1, 2>(),
43-
// Keep alive, ownership: `Context` keeps `self` alive.
44-
py::keep_alive<3, 1>())
45-
.def("Initialize", &Simulator<T>::Initialize)
46-
.def("StepTo", &Simulator<T>::StepTo)
47-
.def("get_context", &Simulator<T>::get_context, py_reference_internal)
48-
.def("get_integrator", &Simulator<T>::get_integrator, py_reference_internal)
49-
.def("get_mutable_integrator", &Simulator<T>::get_mutable_integrator,
50-
py_reference_internal)
51-
.def("get_mutable_context", &Simulator<T>::get_mutable_context,
52-
py_reference_internal)
53-
.def("set_publish_every_time_step",
54-
&Simulator<T>::set_publish_every_time_step)
55-
.def("set_target_realtime_rate", &Simulator<T>::set_target_realtime_rate);
19+
py::module::import("pydrake.systems.framework");
20+
21+
auto bind_scalar_types = [m](auto dummy) {
22+
using T = decltype(dummy);
23+
DefineTemplateClassWithDefault<IntegratorBase<T>>(
24+
m, "IntegratorBase", GetPyParam<T>())
25+
.def("set_fixed_step_mode", &IntegratorBase<T>::set_fixed_step_mode)
26+
.def("get_fixed_step_mode", &IntegratorBase<T>::get_fixed_step_mode)
27+
.def("set_target_accuracy", &IntegratorBase<T>::set_target_accuracy)
28+
.def("get_target_accuracy", &IntegratorBase<T>::get_target_accuracy)
29+
.def("set_maximum_step_size", &IntegratorBase<T>::set_maximum_step_size)
30+
.def("get_maximum_step_size", &IntegratorBase<T>::get_maximum_step_size)
31+
.def("set_requested_minimum_step_size",
32+
&IntegratorBase<T>::set_requested_minimum_step_size)
33+
.def("get_requested_minimum_step_size",
34+
&IntegratorBase<T>::get_requested_minimum_step_size)
35+
.def("set_throw_on_minimum_step_size_violation",
36+
&IntegratorBase<T>::set_throw_on_minimum_step_size_violation)
37+
.def("get_throw_on_minimum_step_size_violation",
38+
&IntegratorBase<T>::get_throw_on_minimum_step_size_violation);
39+
40+
DefineTemplateClassWithDefault<Simulator<T>>(
41+
m, "Simulator", GetPyParam<T>())
42+
.def(py::init<const System<T>&>(),
43+
// Keep alive, reference: `self` keeps `System` alive.
44+
py::keep_alive<1, 2>())
45+
.def(py::init<const System<T>&, unique_ptr<Context<T>>>(),
46+
// Keep alive, reference: `self` keeps `System` alive.
47+
py::keep_alive<1, 2>(),
48+
// Keep alive, ownership: `Context` keeps `self` alive.
49+
py::keep_alive<3, 1>())
50+
.def("Initialize", &Simulator<T>::Initialize)
51+
.def("StepTo", &Simulator<T>::StepTo)
52+
.def("get_context", &Simulator<T>::get_context, py_reference_internal)
53+
.def("get_integrator", &Simulator<T>::get_integrator,
54+
py_reference_internal)
55+
.def("get_mutable_integrator", &Simulator<T>::get_mutable_integrator,
56+
py_reference_internal)
57+
.def("get_mutable_context", &Simulator<T>::get_mutable_context,
58+
py_reference_internal)
59+
.def("set_publish_every_time_step",
60+
&Simulator<T>::set_publish_every_time_step)
61+
.def("set_target_realtime_rate", &Simulator<T>::set_target_realtime_rate);
62+
};
63+
type_visit(bind_scalar_types, pysystems::NonSymbolicScalarPack{});
5664
}
5765

5866
} // namespace pydrake

bindings/pydrake/systems/framework_py.cc

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ namespace pydrake {
99
PYBIND11_MODULE(framework, m) {
1010
m.doc() = "Bindings for the core Systems framework.";
1111

12+
// Import autodiff and symbolic modules so that their types are declared for
13+
// use as template parameters.
14+
py::module::import("pydrake.autodiffutils");
15+
py::module::import("pydrake.symbolic");
16+
1217
// Incorporate definitions as pieces (to speed up compilation).
1318
DefineFrameworkPySystems(m);
1419
DefineFrameworkPySemantics(m);

0 commit comments

Comments
 (0)