Skip to content

Commit

Permalink
Fix interpolator dim bug (#181)
Browse files Browse the repository at this point in the history
* reset states at dim change

* correct call of interpolator method

* merge setting ori and dim

* add doc string
  • Loading branch information
hermanjakobsen authored Jan 29, 2021
1 parent c2a9938 commit f123b78
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 23 deletions.
17 changes: 6 additions & 11 deletions robosuite/controllers/controller_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,26 @@ def controller_factory(name, params):
if name == "OSC_POSE":
ori_interpolator = None
if interpolator is not None:
interpolator.dim = 3 # EE control uses dim 3 for pos and ori each
interpolator.set_states(dim=3) # EE control uses dim 3 for pos and ori each
ori_interpolator = deepcopy(interpolator)
ori_interpolator.ori_interpolate = "euler"
ori_interpolator.set_states(ori="euler")
params["control_ori"] = True
return OperationalSpaceController(interpolator_pos=interpolator,
interpolator_ori=ori_interpolator, **params)

if name == "OSC_POSITION":
if interpolator is not None:
interpolator.dim = 3 # EE control uses dim 3 for pos
interpolator.set_states(dim=3) # EE control uses dim 3 for pos
params["control_ori"] = False
return OperationalSpaceController(interpolator_pos=interpolator, **params)

if name == "IK_POSE":
ori_interpolator = None
if interpolator is not None:
interpolator.dim = 3 # EE IK control uses dim 3 for pos and dim 4 for ori
interpolator.start = np.zeros(3)
interpolator.goal = np.zeros(3)
interpolator.set_states(dim=3) # EE IK control uses dim 3 for pos and dim 4 for ori
ori_interpolator = deepcopy(interpolator)
ori_interpolator.dim = 4
ori_interpolator.start = np.array((0, 0, 0, 1))
ori_interpolator.goal = np.array((0, 0, 0, 1))
ori_interpolator.ori_interpolate = "quat"

ori_interpolator.set_states(dim=4, ori="quat")

# Import pybullet server if necessary
global pybullet_server
from .ik import InverseKinematicsController
Expand Down
37 changes: 28 additions & 9 deletions robosuite/controllers/interpolators/linear_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,34 @@ def __init__(self,
use_delta_goal=False,
ori_interpolate=None,
):

self.dim = ndim # Number of dimensions to interpolate
self.order = 1 # Order of the interpolator (1 = linear)
self.step = 0 # Current step of the interpolator
self.dim = ndim # Number of dimensions to interpolate
self.ori_interpolate = ori_interpolate # Whether this is interpolating orientation or not
self.order = 1 # Order of the interpolator (1 = linear)
self.step = 0 # Current step of the interpolator
self.total_steps = \
np.ceil(ramp_ratio * controller_freq / policy_freq) # Total num steps per interpolator action
self.use_delta_goal = use_delta_goal # Whether to use delta or absolute goals (currently
# not implemented yet- TODO)
self.ori_interpolate = ori_interpolate # Whether this is interpolating orientation or not
np.ceil(ramp_ratio * controller_freq / policy_freq) # Total num steps per interpolator action
self.use_delta_goal = use_delta_goal # Whether to use delta or absolute goals (currently
# not implemented yet- TODO)
self.set_states(dim=ndim, ori=ori_interpolate)

def set_states(self, dim=None, ori=None):
"""
Updates self.dim and self.ori_interpolate.
Initializes self.start and self.goal with correct dimensions.
Args:
ndim (None or int): Number of dimensions to interpolate
ori_interpolate (None or str): If set, assumes that we are interpolating angles (orientation)
Specified string determines assumed type of input:
`'euler'`: Euler orientation inputs
`'quat'`: Quaternion inputs
"""
# Update self.dim and self.ori_interpolate
self.dim = dim if dim is not None else self.dim
self.ori_interpolate = ori if ori is not None else self.ori_interpolate

# Set start and goal states
if self.ori_interpolate is not None:
Expand All @@ -52,7 +71,7 @@ def __init__(self,
else: # quaternions
self.start = np.array((0, 0, 0, 1))
else:
self.start = np.zeros(ndim)
self.start = np.zeros(self.dim)
self.goal = np.array(self.start)

def set_goal(self, goal):
Expand Down
2 changes: 1 addition & 1 deletion robosuite/controllers/joint_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def run_controller(self):
if self.interpolator is not None:
# Linear case
if self.interpolator.order == 1:
desired_qpos = self.interpolator.get_interpolated_goal(self.joint_pos)
desired_qpos = self.interpolator.get_interpolated_goal()
else:
# Nonlinear case not currently supported
pass
Expand Down
2 changes: 1 addition & 1 deletion robosuite/controllers/joint_tor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def run_controller(self):
if self.interpolator is not None:
# Linear case
if self.interpolator.order == 1:
self.current_torque = self.interpolator.get_interpolated_goal(self.current_torque)
self.current_torque = self.interpolator.get_interpolated_goal()
else:
# Nonlinear case not currently supported
pass
Expand Down
2 changes: 1 addition & 1 deletion robosuite/controllers/joint_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def run_controller(self):
if self.interpolator is not None:
if self.interpolator.order == 1:
# Linear case
self.current_vel = self.interpolator.get_interpolated_goal(self.current_vel)
self.current_vel = self.interpolator.get_interpolated_goal()
else:
# Nonlinear case not currently supported
pass
Expand Down

0 comments on commit f123b78

Please sign in to comment.