Gripper penalty and blocking environment #65

Merged · 6 commits · Jun 25, 2024
README.md: 10 changes (9 additions, 1 deletion)
@@ -21,7 +21,15 @@ SERL provides a set of libraries, env wrappers, and examples to train RL policies
- [Contribution](#contribution)
- [Citation](#citation)

- ## Major bug fix
+ ## Major updates
+ #### June 24, 2024
+ For tasks that involve controlling the gripper with SERL (e.g., picking up objects), we strongly recommend adding a small penalty to gripper action changes, as it greatly improves training speed.
+ For details, please refer to [PR #65](https://github.com/rail-berkeley/serl/pull/65).
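A minimal sketch of the idea (illustrative names, not the exact SERL API; the actual change is in the `franka_env.py` diff below): keep the sparse success reward of 1, and subtract a small constant whenever a gripper command actually toggles the gripper state.

```python
GRIPPER_PENALTY = 0.1  # small relative to the sparse success reward of 1.0

def reward_with_gripper_penalty(success: bool, gripper_toggled: bool) -> float:
    """Sparse task reward, minus a penalty when the gripper actually toggles."""
    reward = 1.0 if success else 0.0
    if gripper_toggled:
        reward -= GRIPPER_PENALTY  # discourages needless open/close chatter
    return reward
```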


+ Further, we also recommend providing interventions online during training, in addition to loading offline demos. If you have a Franka robot and a SpaceMouse, this can be as easy as touching the SpaceMouse during training.

#### April 25, 2024
We fixed a major issue in the intervention action frame. See release [v0.1.1](https://github.com/rail-berkeley/serl/releases/tag/v0.1.1). Please update your code with the main branch.

## Installation
@@ -21,7 +21,7 @@ class BinEnvConfig(DefaultEnvConfig):
]
)
RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0])
- REWARD_THRESHOLD: np.ndarray = np.array([0.01, 0.01, 0.01, 0.2, 0.2, 0.2])
+ REWARD_THRESHOLD: np.ndarray = np.zeros(6)
ACTION_SCALE = np.array([0.05, 0.1, 1])
RANDOM_RESET = False
RANDOM_XY_RANGE = 0.1
serl_robot_infra/franka_env/envs/cable_env/config.py: 2 changes (1 addition, 1 deletion)
@@ -21,7 +21,7 @@ class CableEnvConfig(DefaultEnvConfig):
]
)
RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0])
- REWARD_THRESHOLD: np.ndarray = np.array([0.01, 0.01, 0.01, 0.2, 0.2, 0.2])
+ REWARD_THRESHOLD: np.ndarray = np.zeros(6)
ACTION_SCALE = np.array([0.05, 0.3, 1])
RANDOM_RESET = True
RANDOM_XY_RANGE = 0.1
@@ -19,6 +19,7 @@ def go_to_rest(self, joint_reset=False):
Move to the rest position defined in base class.
Add a small z offset before going to rest to avoid collision with object.
"""
+ self._send_gripper_command(-1)
self._update_currpos()
self._send_pos_command(self.currpos)
time.sleep(0.5)
serl_robot_infra/franka_env/envs/franka_env.py: 40 changes (30 additions, 10 deletions)
@@ -59,6 +59,8 @@ class DefaultEnvConfig:
ABS_POSE_LIMIT_LOW = np.zeros((6,))
COMPLIANCE_PARAM: Dict[str, float] = {}
PRECISION_PARAM: Dict[str, float] = {}
+ BINARY_GRIPPER_THREASHOLD: float = 0.5
+ GRIPPER_PENALTY: float = 0.1


##############################################################################
@@ -94,6 +96,7 @@ def __init__(
self.currjacobian = np.zeros((6, 7))

self.curr_gripper_pos = 0
+ self.gripper_binary_state = 0  # 0 for open, 1 for closed
self.lastsent = time.time()
self.randomreset = config.RANDOM_RESET
self.random_xy_range = config.RANDOM_XY_RANGE
@@ -201,7 +204,7 @@ def step(self, action: np.ndarray) -> tuple:

gripper_action = action[6] * self.action_scale[2]

- self._send_gripper_command(gripper_action)
+ gripper_action_effective = self._send_gripper_command(gripper_action)
self._send_pos_command(self.clip_safety_box(self.nextpos))

self.curr_path_length += 1
@@ -210,11 +213,11 @@

self._update_currpos()
ob = self._get_obs()
- reward = self.compute_reward(ob)
- done = self.curr_path_length >= self.max_episode_length or reward
- return ob, int(reward), done, False, {}
+ reward = self.compute_reward(ob, gripper_action_effective)
+ done = self.curr_path_length >= self.max_episode_length or reward == 1
+ return ob, reward, done, False, {}

- def compute_reward(self, obs) -> bool:
+ def compute_reward(self, obs, gripper_action_effective) -> float:
"""We are using a sparse reward function."""
current_pose = obs["state"]["tcp_pose"]
# convert from quat to euler first
@@ -223,10 +226,15 @@ def compute_reward(self, obs) -> bool:
current_pose = np.hstack([current_pose[:3], euler_angles])
delta = np.abs(current_pose - self._TARGET_POSE)
if np.all(delta < self._REWARD_THRESHOLD):
- return True
+ reward = 1
else:
# print(f'Goal not reached, the difference is {delta}, the desired threshold is {_REWARD_THRESHOLD}')
- return False
+ reward = 0

+ if gripper_action_effective:
+ reward -= self.config.GRIPPER_PENALTY

+ return reward

def crop_image(self, name, image) -> np.ndarray:
"""Crop realsense images to be a square."""
@@ -379,12 +387,24 @@ def _send_pos_command(self, pos: np.ndarray):
def _send_gripper_command(self, pos: float, mode="binary"):
"""Internal function to send gripper command to the robot."""
if mode == "binary":
- if (pos >= -1) and (pos <= -0.9):  # close gripper
+ if (
+     pos <= -self.config.BINARY_GRIPPER_THREASHOLD
+     and self.gripper_binary_state == 0
+ ):  # close gripper
requests.post(self.url + "close_gripper")
- elif (pos >= 0.9) and (pos <= 1):  # open gripper
+ time.sleep(0.6)
+ self.gripper_binary_state = 1
+ return True
+ elif (
+     pos >= self.config.BINARY_GRIPPER_THREASHOLD
+     and self.gripper_binary_state == 1
+ ):  # open gripper
requests.post(self.url + "open_gripper")
+ time.sleep(0.6)
+ self.gripper_binary_state = 0
+ return True
else: # do nothing to the gripper
- return
+ return False
elif mode == "continuous":
raise NotImplementedError("Continuous gripper control is optional")

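One subtlety worth making explicit: `_send_gripper_command` returns True only when a command actually flips the binary gripper state, so `compute_reward` charges `GRIPPER_PENALTY` per toggle, not per step. A standalone sketch of just that toggle logic (pure Python, with the HTTP call to the robot server stubbed out; the 0.5 default mirrors `BINARY_GRIPPER_THREASHOLD` above):

```python
class BinaryGripperSketch:
    """Toggle logic of the binary gripper command, robot call omitted."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.state = 0  # 0 for open, 1 for closed

    def command(self, pos: float) -> bool:
        """Return True only if the command actually changes the gripper state."""
        if pos <= -self.threshold and self.state == 0:
            self.state = 1  # would POST close_gripper and block ~0.6 s
            return True
        if pos >= self.threshold and self.state == 1:
            self.state = 0  # would POST open_gripper and block ~0.6 s
            return True
        return False  # already in the requested state, or inside the dead band

gripper = BinaryGripperSketch()
assert gripper.command(-1.0) is True   # closes: this step gets penalized
assert gripper.command(-1.0) is False  # already closed: no penalty
assert gripper.command(1.0) is True    # opens: penalized again
```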
@@ -27,6 +27,7 @@ def go_to_rest(self, joint_reset=False):
Move to the rest position defined in base class.
Add a small z offset before going to rest to avoid collision with object.
"""
+ self._send_gripper_command(-1)
self._update_currpos()
self._send_pos_command(self.currpos)
time.sleep(0.5)
@@ -18,6 +18,7 @@ def go_to_rest(self, joint_reset=False):
Move to the rest position defined in base class.
Add a small z offset before going to rest to avoid collision with object.
"""
+ self._send_gripper_command(-1)
self._update_currpos()
self._send_pos_command(self.currpos)
time.sleep(0.5)
serl_robot_infra/franka_env/envs/wrappers.py: 15 changes (9 additions, 6 deletions)
@@ -49,8 +49,9 @@ def compute_reward(self, obs):

def step(self, action):
obs, rew, done, truncated, info = self.env.step(action)
- rew = self.compute_reward(self.env.get_front_cam_obs())
- done = done or rew
+ success = self.compute_reward(self.env.get_front_cam_obs())
+ rew += success
+ done = done or success
return obs, rew, done, truncated, info


@@ -72,8 +73,9 @@ def compute_reward(self, obs):

def step(self, action):
obs, rew, done, truncated, info = self.env.step(action)
- rew = self.compute_reward(self.env.get_front_cam_obs())
- done = done or rew
+ success = self.compute_reward(self.env.get_front_cam_obs())
+ rew += success
+ done = done or success
return obs, rew, done, truncated, info


@@ -94,8 +96,9 @@ def compute_reward(self, obs):

def step(self, action):
obs, rew, done, truncated, info = self.env.step(action)
- rew = self.compute_reward(obs)
- done = done or rew
+ success = self.compute_reward(obs)
+ rew += success
+ done = done or success
return obs, rew, done, truncated, info


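The three wrappers above repeat the same pattern, so a generic form may be easier to see at a glance (a sketch assuming gymnasium-style 5-tuple steps; `success_fn` is an illustrative parameter, not part of SERL):

```python
import gymnasium as gym

class SuccessBonusWrapper(gym.Wrapper):
    """Add a detector-based success bonus to the env reward and
    terminate the episode once the detector fires."""

    def __init__(self, env, success_fn):
        super().__init__(env)
        self.success_fn = success_fn  # maps an observation to True/False

    def step(self, action):
        obs, rew, done, truncated, info = self.env.step(action)
        success = self.success_fn(obs)
        rew += float(success)   # keep any gripper penalty, add the bonus
        done = done or success  # stop the episode on detected success
        return obs, rew, done, truncated, info
```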