Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ball joints #56

Merged
merged 4 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 47 additions & 26 deletions stac_mjx/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,37 @@
import imageio
from tqdm import tqdm

# Root position (3) + quaternion (7) in qpos
_ROOT_QPOS_LB = -jp.inf * jp.ones(7)
_ROOT_QPOS_UB = jp.inf * jp.ones(7)
# Root = position (3) + quaternion (4)
_ROOT_QPOS_LB = jp.concatenate([-jp.inf * jp.ones(3), -1.0 * jp.ones(4)])
_ROOT_QPOS_UB = jp.concatenate([jp.inf * jp.ones(3), 1.0 * jp.ones(4)])

# mujoco jnt_type enums: https://mujoco.readthedocs.io/en/latest/APIreference/APItypes.html#mjtjoint
_MUJOCO_JOINT_TYPE_DIMS = {
mujoco.mjtJoint.mjJNT_FREE: 7,
mujoco.mjtJoint.mjJNT_BALL: 4,
mujoco.mjtJoint.mjJNT_SLIDE: 1,
mujoco.mjtJoint.mjJNT_HINGE: 1,
}


def _align_joint_dims(types, ranges, names):
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved
"""Creates bounds and joint names aligned with qpos dimensions."""
lb = []
ub = []
part_names = []
for type, range, name in zip(types, ranges, names):
jf514 marked this conversation as resolved.
Show resolved Hide resolved
dims = _MUJOCO_JOINT_TYPE_DIMS[type]
# Set inf bounds for freejoint
if type == mujoco.mjtJoint.mjJNT_FREE:
lb.append(_ROOT_QPOS_LB)
ub.append(_ROOT_QPOS_UB)
part_names += [name] * dims
else:
lb.append(range[0] * jp.ones(dims))
ub.append(range[1] * jp.ones(dims))
part_names += [name] * dims

# Prepend this to list of part names for one-to-one correspondence with qpos
_ROOT_NAMES = ["root"] * 6
return jp.minimum(jp.concatenate(lb), 0.0), jp.concatenate(ub), part_names
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved


class STAC:
Expand All @@ -51,37 +76,38 @@ def __init__(
self._kp_names = kp_names
self._root = mjcf.from_path(xml_path)
(
mj_model,
self._mj_model,
self._body_site_idxs,
self._is_regularized,
self._part_names,
self._body_names,
) = self._create_body_sites(self._root)

self._body_names = [
self._mj_model.body(i).name for i in range(self._mj_model.nbody)
]

joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]

# Set up bounds and part_names based on joint ranges, taking into account the dimensionality of parameters
self._lb, self._ub, self._part_names = _align_joint_dims(
self._mj_model.jnt_type, self._mj_model.jnt_range, joint_names
)

self._indiv_parts = self.part_opt_setup()

self._trunk_kps = jp.array(
[n in self.model_cfg["TRUNK_OPTIMIZATION_KEYPOINTS"] for n in kp_names],
)

mj_model.opt.solver = {
self._mj_model.opt.solver = {
"cg": mujoco.mjtSolver.mjSOL_CG,
"newton": mujoco.mjtSolver.mjSOL_NEWTON,
}[stac_cfg.mujoco.solver.lower()]

mj_model.opt.iterations = stac_cfg.mujoco.iterations
mj_model.opt.ls_iterations = stac_cfg.mujoco.ls_iterations
self._mj_model.opt.iterations = stac_cfg.mujoco.iterations
self._mj_model.opt.ls_iterations = stac_cfg.mujoco.ls_iterations

# Runs faster on GPU with this
mj_model.opt.jacobian = 0 # dense

self._mj_model = mj_model

# Set joint bounds
self._lb = jp.minimum(
jp.concatenate([_ROOT_QPOS_LB, self._mj_model.jnt_range[1:][:, 0]]),
0.0,
)
self._ub = jp.concatenate([_ROOT_QPOS_UB, self._mj_model.jnt_range[1:][:, 1]])
self._mj_model.opt.jacobian = 0 # dense

def part_opt_setup(self):
"""Set up the lists of indices for part optimization.
Expand Down Expand Up @@ -142,9 +168,6 @@ def _create_body_sites(self, root: mjcf.Element):
key: int(axis.convert_key_item(key))
for key in self.model_cfg["KEYPOINT_MODEL_PAIRS"].keys()
}
part_names = _ROOT_NAMES + physics.named.data.qpos.axes.row.names

body_names = physics.named.data.xpos.axes.row.names

# Define which offsets to regularize
is_regularized = []
Expand All @@ -160,8 +183,6 @@ def _create_body_sites(self, root: mjcf.Element):
physics.model.ptr,
jp.array(list(site_index_map.values())),
is_regularized,
part_names,
body_names,
)

def _chunk_kp_data(self, kp_data):
Expand Down
71 changes: 70 additions & 1 deletion tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from stac_mjx import main
from stac_mjx import utils
from stac_mjx.controller import STAC
from stac_mjx.controller import STAC, _align_joint_dims
from mujoco import _structs

_BASE_PATH = Path.cwd()
Expand All @@ -25,3 +25,72 @@ def test_init_stac(mocap_nwb, stac_config, rodent_config):
assert stac.model_cfg == model_cfg
assert stac._kp_names == sorted_kp_names
assert isinstance(stac._mj_model, _structs.MjModel)


def test_align_joint_dims():
from jax import numpy as jp
import mujoco

joint_types = [
mujoco.mjtJoint.mjJNT_FREE,
mujoco.mjtJoint.mjJNT_HINGE,
mujoco.mjtJoint.mjJNT_BALL,
mujoco.mjtJoint.mjJNT_SLIDE,
]
ranges = [[0.0, 0.0], [-0.1, 0.1], [0.0, 1.0], [-0.5, 0.5]]
names = ["root", "hingejoint", "balljoint", "slidejoint"]
lb, ub, part_names = _align_joint_dims(joint_types, ranges, names)
print(lb)

true_lb = jp.array(
[
-jp.inf,
-jp.inf,
-jp.inf,
-1.0,
-1.0,
-1.0,
-1.0,
-0.1,
0.0,
0.0,
0.0,
0.0,
-0.5,
]
)

true_ub = jp.array(
[
jp.inf,
jp.inf,
jp.inf,
1.0,
1.0,
1.0,
1.0,
0.1,
1.0,
1.0,
1.0,
1.0,
0.5,
]
)
assert jp.array_equal(lb, true_lb)
assert jp.array_equal(ub, true_ub)
assert part_names == [
"root",
"root",
"root",
"root",
"root",
"root",
"root",
"hingejoint",
"balljoint",
"balljoint",
"balljoint",
"balljoint",
"slidejoint",
]
Loading