Skip to content

Commit

Permalink
simulate python: glfw linkage no longer needed, C++ physics loop reused
Browse files Browse the repository at this point in the history
this makes mjsimulate a shared library (it has to be for glfw linkage to be
hidden)

part of the Python's bindings callback module has a header now
  • Loading branch information
aftersomemath committed Aug 25, 2022
1 parent b9372cc commit 1fd9e0e
Show file tree
Hide file tree
Showing 11 changed files with 695 additions and 816 deletions.
22 changes: 16 additions & 6 deletions python/mujoco/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ add_compile_options("${MUJOCO_HARDEN_COMPILE_OPTIONS}")
add_link_options("${MUJOCO_HARDEN_LINK_OPTIONS}")

find_package(Python3 COMPONENTS Interpreter Development)
find_package(glfw3 3.3 REQUIRED)

include(FindOrFetch)

Expand Down Expand Up @@ -266,7 +265,7 @@ target_link_libraries(errors_header INTERFACE crossplatform func_wrap mujoco)
add_library(raw INTERFACE)
target_sources(raw INTERFACE raw.h)
set_target_properties(raw PROPERTIES PUBLIC_HEADER raw.h)
target_link_libraries(raw INTERFACE mujoco mjsimulate)
target_link_libraries(raw INTERFACE mujoco)

add_library(structs_header INTERFACE)
target_sources(
Expand Down Expand Up @@ -334,17 +333,29 @@ macro(mujoco_pybind11_module name)
endif()
endmacro()

add_library(callbacks_header INTERFACE)
target_sources(callbacks_header INTERFACE callbacks.h)
set_target_properties(callbacks_header PROPERTIES PUBLIC_HEADER callbacks.h)
target_link_libraries(
callbacks_header
INTERFACE errors_header
structs_header
mujoco
raw
)

mujoco_pybind11_module(_callbacks callbacks.cc)
target_link_libraries(
_callbacks
PRIVATE errors_header
mujoco
raw
structs_header
callbacks_header
)

mujoco_pybind11_module(_constants constants.cc)
target_link_libraries(_constants PRIVATE mujoco mjsimulate)
target_link_libraries(_constants PRIVATE mujoco)

mujoco_pybind11_module(_enums enums.cc)
target_link_libraries(
Expand Down Expand Up @@ -411,8 +422,8 @@ target_link_libraries(
PRIVATE mjsimulate
mujoco
raw
glfw
structs_header)
structs_header
Eigen3::Eigen)
target_link_options(_simulate PRIVATE -Wl,-no-as-needed)

set(LIBRARIES_FOR_WHEEL
Expand Down Expand Up @@ -453,7 +464,6 @@ if(MUJOCO_PYTHON_MAKE_WHEEL)
_render
_rollout
_simulate
_structs
mujoco
mjsimulate
)
Expand Down
226 changes: 1 addition & 225 deletions python/mujoco/callbacks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,152 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstddef>
#include <cstdint>
#include <exception>
#include <limits>
#include <sstream>
#include <type_traits>

#include <mujoco/mujoco.h>
#include "errors.h"
#include "structs.h"
#include "raw.h"
#include <pybind11/eval.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include "callbacks.h"

namespace mujoco::python {
namespace {
namespace py = ::pybind11;

[[noreturn]] static void EscapeWithPythonException() {
mju_error("Python exception raised");
std::terminate(); // not actually reachable, mju_error doesn't return
}

template <typename T, typename U>
using enable_if_not_const_t =
std::enable_if_t<std::is_same_v<std::remove_const_t<T>, T>, U>;

// MuJoCo passes raw mjModel* and mjData* as arguments to callbacks, but Python
// callables expect the corresponding MjWrapper objects. To avoid creating new
// wrappers each time we enter callbacks, we instead maintain a global lookup
// table that associates raw MuJoCo struct pointers back to the pointers to
// their corresponding wrappers.
template <typename Raw>
static enable_if_not_const_t<Raw, py::handle> MjWrapperLookup(Raw* ptr) {
using LookupFnType = MjWrapper<Raw>* (Raw*);
static LookupFnType* const lookup = []() -> LookupFnType* {
py::gil_scoped_acquire gil;
auto m = py::module_::import("mujoco._structs");
pybind11::handle builtins(PyEval_GetBuiltins());
if (!builtins.contains(MjWrapper<Raw>::kFromRawPointer)) {
return nullptr;
} else {
try {
return reinterpret_cast<LookupFnType*>(
builtins[MjWrapper<Raw>::kFromRawPointer]
.template cast<std::uintptr_t>());
} catch (const py::cast_error&) {
return nullptr;
}
}
}();

MjWrapper<Raw>* wrapper = nullptr;
if (lookup) {
wrapper = lookup(ptr);
} else {
{
py::gil_scoped_acquire gil;
PyErr_SetString(
UnexpectedError::GetPyExc(),
"_structs module did not register its raw pointer lookup functions");
}
}

if (!wrapper) {
{
py::gil_scoped_acquire gil;
PyErr_SetString(
UnexpectedError::GetPyExc(),
"cannot find the corresponding wrapper for the raw mjStruct");
}
}

// Now we find the existing Python instance of our wrapper.
// TODO(stunya): Figure out a way to do this without invoking py::detail.
{
py::gil_scoped_acquire gil;
const auto [src, type] =
py::detail::type_caster_base<MjWrapper<Raw>>::src_and_type(wrapper);
if (type) {
py::handle instance =
py::detail::find_registered_python_instance(wrapper, type);
if (!instance) {
if (!PyErr_Occurred()) {
PyErr_SetString(
UnexpectedError::GetPyExc(),
"cannot find the Python instance of the MjWrapper");
}
} else {
return instance;
}
} else {
if (!PyErr_Occurred()) {
PyErr_SetString(
UnexpectedError::GetPyExc(),
"MjWrapper type isn't registered with pybind11");
}
}
}

EscapeWithPythonException();
}

template <typename Raw>
static const py::handle MjWrapperLookup(const Raw* ptr) {
return MjWrapperLookup(const_cast<Raw*>(ptr));
}

template <typename Return, typename... Args>
static Return
CallPyCallback(const char* name, PyObject* py_callback, Args... args) {
{
py::gil_scoped_acquire gil;
if (!py_callback) {
std::ostringstream msg;
msg << "py_" << name << " is null";
PyErr_SetString(UnexpectedError::GetPyExc(), msg.str().c_str());
} else {
py::handle callback(py_callback);
try {
if constexpr (std::is_void_v<Return>) {
callback(args...);
return;
} else {
return callback(args...).template cast<Return>();
}
} catch (py::error_already_set& e) {
e.restore();
} catch (const py::cast_error&) {
std::ostringstream msg;
msg << name << " callback did not return ";
if constexpr (std::is_integral_v<Return>) {
msg << "an integer";
} else if constexpr (std::is_floating_point_v<Return>) {
msg << "a number";
} else {
msg << "the correct type";
}
PyErr_SetString(PyExc_TypeError, msg.str().c_str());
}
}
}
EscapeWithPythonException();
}

static PyObject* py_mju_user_warning = nullptr;
static void PyMjuUserWarning(const char* msg) {
CallPyCallback<void>("mju_user_warning", py_mju_user_warning, msg);
Expand Down Expand Up @@ -222,90 +82,6 @@ PyMjcbActBias(const raw::MjModel* m, const raw::MjData* d, int id) {
MjWrapperLookup(m), MjWrapperLookup(d), id);
}

// If the Python object is a ctypes function pointer, returns the corresponding
// C function pointer. Otherwise, returns a null pointer.
template <typename FuncPtr>
static FuncPtr GetCFuncPtr(py::handle h) {
struct CTypes { PyObject* cfuncptr; PyObject* cast; PyObject* c_void_p; };
static const CTypes ctypes = []() -> CTypes {
try {
auto m = py::module_::import("ctypes");
PyObject* cfuncptr = m.attr("_CFuncPtr").ptr();
PyObject* cast = m.attr("cast").ptr();
PyObject* c_void_p = m.attr("c_void_p").ptr();
Py_XINCREF(cfuncptr);
Py_XINCREF(cast);
Py_XINCREF(c_void_p);
return {cfuncptr, cast, c_void_p};
} catch (const py::error_already_set&) {
return {nullptr, nullptr, nullptr};
}
}();

if (!ctypes.cfuncptr) {
throw UnexpectedError("cannot find `ctypes._CFuncPtr`");
}
const int is_cfuncptr = PyObject_IsInstance(h.ptr(), ctypes.cfuncptr);
if (is_cfuncptr == -1) {
throw py::error_already_set();
} else if (is_cfuncptr) {
if (!ctypes.cast) {
throw UnexpectedError("cannot find `ctypes.cast`");
}
if (!ctypes.c_void_p) {
throw UnexpectedError("cannot find `ctypes.c_void_p`");
}
const uintptr_t func_address =
py::handle(ctypes.cast)(h, py::handle(ctypes.c_void_p))
.attr("value")
.template cast<std::uintptr_t>();
return reinterpret_cast<FuncPtr>(func_address);
} else {
return nullptr;
}
}

static bool IsCallable(py::handle h) {
static PyObject* const is_callable = []() -> PyObject* {
try{
PyObject* o = py::eval("callable").ptr();
Py_XINCREF(o);
return o;
} catch (const py::error_already_set&) {
return nullptr;
}
}();

if (!is_callable) {
throw UnexpectedError("cannot find `callable`");
}

return py::handle(is_callable)(h).cast<bool>();
}

template <typename CFuncPtr>
void SetCallback(py::handle h, CFuncPtr py_trampoline,
PyObject** py_callback, CFuncPtr* mj_callback) {
CFuncPtr cfuncptr = GetCFuncPtr<CFuncPtr>(h);
if (h.is_none()) {
Py_XDECREF(*py_callback);
*py_callback = nullptr;
*mj_callback = nullptr;
} else if (cfuncptr) {
Py_XDECREF(*py_callback);
Py_INCREF(h.ptr());
*py_callback = h.ptr();
*mj_callback = cfuncptr;
} else if (IsCallable(h)) {
Py_XDECREF(*py_callback);
Py_INCREF(h.ptr());
*py_callback = h.ptr();
*mj_callback = py_trampoline;
} else {
throw py::type_error("callback is not an Optional[Callable]");
}
}

py::object GetCallback(PyObject* py_callback) {
if (!py_callback) {
return py::none();
Expand Down
Loading

0 comments on commit 1fd9e0e

Please sign in to comment.