Skip to content

Commit

Permalink
WIP Python bindings for Simulate tested on Linux only
Browse files Browse the repository at this point in the history
  • Loading branch information
aftersomemath committed Aug 3, 2022
1 parent 7fb1ffe commit aaa0540
Show file tree
Hide file tree
Showing 8 changed files with 517 additions and 6 deletions.
53 changes: 51 additions & 2 deletions python/mujoco/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ add_compile_options("${MUJOCO_HARDEN_COMPILE_OPTIONS}")
add_link_options("${MUJOCO_HARDEN_LINK_OPTIONS}")

find_package(Python3 COMPONENTS Interpreter Development)
find_package(glfw3 3.3 REQUIRED)

include(FindOrFetch)

Expand Down Expand Up @@ -115,6 +116,40 @@ if(NOT TARGET mujoco)
endif()
endif()

# ==================== MJSIMULATE LIBRARY ==========================================
if(NOT TARGET mjsimulate)
find_library(MJSIMULATE_LIBRARY mjsimulate mjsimulate HINTS ${MUJOCO_LIBRARY_DIR} REQUIRED)
find_path(MJSIMULATE_INCLUDE mujoco/simulate.h mujoco/uitools.h mujoco/array_safety.h HINTS ${MUJOCO_INCLUDE_DIR} REQUIRED)
message("MuJoCo Simulate is at ${MJSIMULATE_LIBRARY}")
message("MuJoCo Simulate headers are at ${MJSIMULATE_INCLUDE}")
add_library(mjsimulate SHARED IMPORTED)
if(WIN32)
set_target_properties(mjsimulate PROPERTIES IMPORTED_IMPLIB "${MJSIMULATE_LIBRARY}")
else()
set_target_properties(mjsimulate PROPERTIES IMPORTED_LOCATION "${MJSIMULATE_LIBRARY}")
endif()
target_include_directories(mjsimulate INTERFACE "${MJSIMULATE_INCLUDE}")
if(APPLE)
execute_process(
COMMAND otool -XD ${MJSIMULATE_LIBRARY}
COMMAND head -n 1
COMMAND xargs dirname
COMMAND xargs echo -n
OUTPUT_VARIABLE MJSIMULATE_INSTALL_NAME_DIR
)
set_target_properties(mjsimulate PROPERTIES INSTALL_NAME_DIR "${MJSIMULATE_INSTALL_NAME_DIR}")
elseif(UNIX)
execute_process(
COMMAND objdump -p ${MJSIMULATE_LIBRARY}
COMMAND grep SONAME
COMMAND grep -Po [^\\s]+$
COMMAND xargs echo -n
OUTPUT_VARIABLE MJSIMULATE_SONAME
)
set_target_properties(mjsimulate PROPERTIES IMPORTED_SONAME "${MJSIMULATE_SONAME}")
endif()
endif()

# ==================== ABSEIL ==================================================
set(MUJOCO_PYTHON_ABSL_TARGETS absl::core_headers absl::flat_hash_map absl::span)
findorfetch(
Expand Down Expand Up @@ -231,7 +266,7 @@ target_link_libraries(errors_header INTERFACE crossplatform func_wrap mujoco)
add_library(raw INTERFACE)
target_sources(raw INTERFACE raw.h)
set_target_properties(raw PROPERTIES PUBLIC_HEADER raw.h)
target_link_libraries(raw INTERFACE mujoco)
target_link_libraries(raw INTERFACE mujoco mjsimulate)

add_library(structs_header INTERFACE)
target_sources(
Expand Down Expand Up @@ -309,7 +344,7 @@ target_link_libraries(
)

mujoco_pybind11_module(_constants constants.cc)
target_link_libraries(_constants PRIVATE mujoco)
target_link_libraries(_constants PRIVATE mujoco mjsimulate)

mujoco_pybind11_module(_enums enums.cc)
target_link_libraries(
Expand Down Expand Up @@ -370,6 +405,16 @@ target_link_libraries(
structs_header
)

mujoco_pybind11_module(_simulate simulate.cc)
target_link_libraries(
_simulate
PRIVATE mjsimulate
mujoco
raw
glfw
structs_header)
target_link_options(_simulate PRIVATE -Wl,-no-as-needed)

set(LIBRARIES_FOR_WHEEL
"$<TARGET_FILE:_callbacks>"
"$<TARGET_FILE:_constants>"
Expand All @@ -378,8 +423,10 @@ set(LIBRARIES_FOR_WHEEL
"$<TARGET_FILE:_functions>"
"$<TARGET_FILE:_render>"
"$<TARGET_FILE:_rollout>"
"$<TARGET_FILE:_simulate>"
"$<TARGET_FILE:_structs>"
"$<TARGET_FILE:mujoco>"
"$<TARGET_FILE:mjsimulate>"
)

if(MUJOCO_PYTHON_MAKE_WHEEL)
Expand All @@ -405,7 +452,9 @@ if(MUJOCO_PYTHON_MAKE_WHEEL)
_functions
_render
_rollout
_simulate
_structs
mujoco
mjsimulate
)
endif()
4 changes: 4 additions & 0 deletions python/mujoco/constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <utility>
#include <vector>
#include <mujoco/mjmodel.h>
#include <mujoco/simulate.h>
#include <mujoco/mjvisualize.h>
#include <mujoco/mujoco.h>
#include <pybind11/cast.h>
Expand Down Expand Up @@ -78,6 +79,9 @@ PYBIND11_MODULE(_constants, pymodule) {
// from mujoco.h
X(mjVERSION_HEADER);

// from simulate.h
X(mujoco::Simulate::kMaxFilenameLength);

#undef X
pymodule.attr("mjDISABLESTRING") = MakeTuple(mjDISABLESTRING);
pymodule.attr("mjENABLESTRING") = MakeTuple(mjENABLESTRING);
Expand Down
4 changes: 4 additions & 0 deletions python/mujoco/raw.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <mujoco/mjmodel.h>
#include <mujoco/mjrender.h>
#include <mujoco/mjvisualize.h>
#include <mujoco/simulate.h>

// Type aliases for MuJoCo C structs to allow us refer to consistently refer
// to them under the "raw" namespace.
Expand Down Expand Up @@ -55,6 +56,9 @@ using MjvOption = ::mjvOption;
using MjvScene = ::mjvScene;
using MjvFigure = ::mjvFigure;

// From simulate.h
using Simulate = ::mujoco::Simulate;

} // namespace mujoco::raw

#endif // MUJOCO_PYTHON_RAW_H_
142 changes: 142 additions & 0 deletions python/mujoco/simulate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2022 DeepMind Technologies Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <array>
#include <cstdio>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>

#include <mujoco/mujoco.h>
#include <mujoco/simulate.h>
#include "raw.h"
#include "structs.h"
#include <pybind11/buffer_info.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace mujoco::python {

namespace {

namespace py = ::pybind11;

const auto simulate_doc = R"(
Python wrapper for the Simulate class
)";

// We define SimulateWrapper here instead of in structs because
// we do not want to make _structs dependent on gflw

PYBIND11_MODULE(_simulate, pymodule) {
namespace py = ::pybind11;

py::class_<mujoco::raw::Simulate>(pymodule, "Simulate")
.def(py::init<>())
.def("renderloop",
[](mujoco::raw::Simulate& simulate) {
simulate.renderloop();
},
py::call_guard<py::gil_scoped_release>())
.def("load",
[](mujoco::raw::Simulate& simulate, std::string filename, const MjModelWrapper& m, MjDataWrapper& d, bool delete_old_m_d) {
const raw::MjModel* m_ptr = m.get();
raw::MjData* d_ptr = d.get();
simulate.load(filename.c_str(), (mjModel*)m_ptr, d_ptr, delete_old_m_d);
},
py::call_guard<py::gil_scoped_release>())
.def("applyposepertubations", &mujoco::raw::Simulate::applyposepertubations)
.def("applyforceperturbations", &mujoco::raw::Simulate::applyforceperturbations)

.def("lock", // TODO wrap mutex properly as as seperate pybind11 object?
[](mujoco::raw::Simulate& simulate) {
simulate.mtx.lock();
},
py::call_guard<py::gil_scoped_release>())
.def("unlock",
[](mujoco::raw::Simulate& simulate) {
simulate.mtx.unlock();
},
py::call_guard<py::gil_scoped_release>())
.def_readwrite("ctrlnoisestd", &mujoco::raw::Simulate::ctrlnoisestd)
.def_readwrite("ctrlnoiserate", &mujoco::raw::Simulate::ctrlnoiserate)
.def_readwrite("slow_down", &mujoco::raw::Simulate::slow_down)
.def_readwrite("speed_changed", &mujoco::raw::Simulate::speed_changed)
.def("getrefreshRate",
[](mujoco::raw::Simulate& simulate) {
return simulate.vmode.refreshRate;
})

.def_readwrite("busywait", &mujoco::raw::Simulate::busywait)
.def_readwrite("run", &mujoco::raw::Simulate::run)
//.def_readwrite("exitrequest", &mujoco::raw::Simulate::exitrequest)
.def("getexitrequest",
[](mujoco::raw::Simulate& simulate) {
return simulate.exitrequest.load();
}
)
.def("setexitrequest",
[](mujoco::raw::Simulate& simulate, bool exitrequest) {
simulate.exitrequest.store(exitrequest);
}
)
// .def_readwrite("uiloadrequest", &mujoco::raw::Simulate::uiloadrequest)
.def("getuiloadrequest",
[](mujoco::raw::Simulate& simulate) {
return simulate.uiloadrequest.load();
}
)
.def("setuiloadrequest",
[](mujoco::raw::Simulate& simulate, int uiloadrequest) {
simulate.uiloadrequest.store(uiloadrequest);
}
)
.def("uiloadrequest_fetch_sub",
[](mujoco::raw::Simulate& simulate, int arg) {
simulate.uiloadrequest.fetch_sub(arg);
}
)
// .def_readwrite("droploadrequest", &mujoco::raw::Simulate::droploadrequest)
.def("getdroploadrequest",
[](mujoco::raw::Simulate& simulate) {
return simulate.droploadrequest.load();
}
)
.def("setdroploadrequest",
[](mujoco::raw::Simulate& simulate, bool droploadrequest) {
simulate.droploadrequest.store(droploadrequest);
}
)
.def("getdropfilename",
[](mujoco::raw::Simulate& simulate) {
return (char*)simulate.dropfilename;
}
)
.def("getfilename",
[](mujoco::raw::Simulate& simulate) {
return (char*)simulate.filename;
}
)
.def("setloadError",
[](mujoco::raw::Simulate& simulate, std::string& loadError) {
strncpy(simulate.loadError, loadError.c_str(), simulate.kMaxFilenameLength);
}
);
}

} // namespace

}
Loading

0 comments on commit aaa0540

Please sign in to comment.