diff --git a/robosuite/controllers/controller_factory.py b/robosuite/controllers/controller_factory.py index d72e3475e1..700968717b 100644 --- a/robosuite/controllers/controller_factory.py +++ b/robosuite/controllers/controller_factory.py @@ -116,31 +116,26 @@ def controller_factory(name, params): if name == "OSC_POSE": ori_interpolator = None if interpolator is not None: - interpolator.dim = 3 # EE control uses dim 3 for pos and ori each + interpolator.set_states(dim=3) # EE control uses dim 3 for pos and ori each ori_interpolator = deepcopy(interpolator) - ori_interpolator.ori_interpolate = "euler" + ori_interpolator.set_states(ori="euler") params["control_ori"] = True return OperationalSpaceController(interpolator_pos=interpolator, interpolator_ori=ori_interpolator, **params) if name == "OSC_POSITION": if interpolator is not None: - interpolator.dim = 3 # EE control uses dim 3 for pos + interpolator.set_states(dim=3) # EE control uses dim 3 for pos params["control_ori"] = False return OperationalSpaceController(interpolator_pos=interpolator, **params) if name == "IK_POSE": ori_interpolator = None if interpolator is not None: - interpolator.dim = 3 # EE IK control uses dim 3 for pos and dim 4 for ori - interpolator.start = np.zeros(3) - interpolator.goal = np.zeros(3) + interpolator.set_states(dim=3) # EE IK control uses dim 3 for pos and dim 4 for ori ori_interpolator = deepcopy(interpolator) - ori_interpolator.dim = 4 - ori_interpolator.start = np.array((0, 0, 0, 1)) - ori_interpolator.goal = np.array((0, 0, 0, 1)) - ori_interpolator.ori_interpolate = "quat" - + ori_interpolator.set_states(dim=4, ori="quat") + # Import pybullet server if necessary global pybullet_server from .ik import InverseKinematicsController diff --git a/robosuite/controllers/interpolators/linear_interpolator.py b/robosuite/controllers/interpolators/linear_interpolator.py index 17111838d0..d23e9a42b0 100644 --- a/robosuite/controllers/interpolators/linear_interpolator.py +++ b/robosuite/controllers/interpolators/linear_interpolator.py @@ -35,15 +35,34 @@ def __init__(self, use_delta_goal=False, ori_interpolate=None, ): - - self.dim = ndim # Number of dimensions to interpolate - self.order = 1 # Order of the interpolator (1 = linear) - self.step = 0 # Current step of the interpolator + self.dim = ndim # Number of dimensions to interpolate + self.ori_interpolate = ori_interpolate # Whether this is interpolating orientation or not + self.order = 1 # Order of the interpolator (1 = linear) + self.step = 0 # Current step of the interpolator self.total_steps = \ - np.ceil(ramp_ratio * controller_freq / policy_freq) # Total num steps per interpolator action - self.use_delta_goal = use_delta_goal # Whether to use delta or absolute goals (currently - # not implemented yet- TODO) - self.ori_interpolate = ori_interpolate # Whether this is interpolating orientation or not + np.ceil(ramp_ratio * controller_freq / policy_freq) # Total num steps per interpolator action + self.use_delta_goal = use_delta_goal # Whether to use delta or absolute goals (currently + # not implemented yet- TODO) + self.set_states(dim=ndim, ori=ori_interpolate) + + def set_states(self, dim=None, ori=None): + """ + Updates self.dim and self.ori_interpolate. + + Initializes self.start and self.goal with correct dimensions. + + Args: + ndim (None or int): Number of dimensions to interpolate + + ori_interpolate (None or str): If set, assumes that we are interpolating angles (orientation) + Specified string determines assumed type of input: + + `'euler'`: Euler orientation inputs + `'quat'`: Quaternion inputs + """ + # Update self.dim and self.ori_interpolate + self.dim = dim if dim is not None else self.dim + self.ori_interpolate = ori if ori is not None else self.ori_interpolate # Set start and goal states if self.ori_interpolate is not None: @@ -52,7 +71,7 @@ def __init__(self, else: # quaternions self.start = np.array((0, 0, 0, 1)) else: - self.start = np.zeros(ndim) + self.start = np.zeros(self.dim) self.goal = np.array(self.start) def set_goal(self, goal): diff --git a/robosuite/controllers/joint_pos.py b/robosuite/controllers/joint_pos.py index e0a2a52c75..5bfa4b4d6e 100644 --- a/robosuite/controllers/joint_pos.py +++ b/robosuite/controllers/joint_pos.py @@ -226,7 +226,7 @@ def run_controller(self): if self.interpolator is not None: # Linear case if self.interpolator.order == 1: - desired_qpos = self.interpolator.get_interpolated_goal(self.joint_pos) + desired_qpos = self.interpolator.get_interpolated_goal() else: # Nonlinear case not currently supported pass diff --git a/robosuite/controllers/joint_tor.py b/robosuite/controllers/joint_tor.py index cf1e9fc998..9c50c2bfd4 100644 --- a/robosuite/controllers/joint_tor.py +++ b/robosuite/controllers/joint_tor.py @@ -139,7 +139,7 @@ def run_controller(self): if self.interpolator is not None: # Linear case if self.interpolator.order == 1: - self.current_torque = self.interpolator.get_interpolated_goal(self.current_torque) + self.current_torque = self.interpolator.get_interpolated_goal() else: # Nonlinear case not currently supported pass diff --git a/robosuite/controllers/joint_vel.py b/robosuite/controllers/joint_vel.py index 528562a9f4..0948828b9b 100644 --- a/robosuite/controllers/joint_vel.py +++ b/robosuite/controllers/joint_vel.py @@ -161,7 +161,7 @@ def run_controller(self): if self.interpolator is not None: if self.interpolator.order == 1: # Linear case - self.current_vel = self.interpolator.get_interpolated_goal(self.current_vel) + self.current_vel = self.interpolator.get_interpolated_goal() else: # Nonlinear case not currently supported pass