Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lift the requirement of human fingering with RP1M #24

Merged
merged 1 commit into from
Nov 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions robopianist/suite/tasks/piano_with_shadow_hands.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import List, Optional, Sequence, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment
from dm_control import mjcf
from dm_control.composer import variation as base_variation
from dm_control.composer.observation import observable
Expand Down Expand Up @@ -134,6 +135,11 @@ def _set_rewards(self) -> None:
)
if not self._disable_fingering_reward:
self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
else:
# use OT based fingering
print('Fingering is unavailable. OT fingering reward is used.')
self._reward_fn.add("ot_fingering_reward", self._compute_ot_fingering_reward)

if not self._disable_forearm_reward:
self._reward_fn.add("forearm_reward", self._compute_forearm_reward)

Expand Down Expand Up @@ -324,6 +330,44 @@ def _distance_finger_to_key(
)
return float(np.mean(rews))

def _compute_ot_fingering_reward(self, physics: mjcf.Physics) -> float:
""" OT reward calculation from RP1M https://arxiv.org/abs/2408.11048 """
# calcuate fingertip positions
fingertip_pos = [physics.bind(finger).xpos.copy() for finger in self.left_hand.fingertip_sites]
fingertip_pos += [physics.bind(finger).xpos.copy() for finger in self.right_hand.fingertip_sites]

# calcuate the positions of piano keys to press.
keys_to_press = np.flatnonzero(self._goal_current[:-1]) # keys to press
# if no key is pressed
if keys_to_press.shape[0] == 0:
return 1.

# calculate key pos
key_pos = []
for key in keys_to_press:
key_geom = self.piano.keys[key].geom[0]
key_geom_pos = physics.bind(key_geom).xpos.copy()
key_geom_pos[-1] += 0.5 * physics.bind(key_geom).size[2]
key_geom_pos[0] += 0.35 * physics.bind(key_geom).size[0]
key_pos.append(key_geom_pos.copy())

# calcualte the distance between keys and fingers
dist = np.full((len(fingertip_pos), len(key_pos)), 100.)
for i, finger in enumerate(fingertip_pos):
for j, key in enumerate(key_pos):
dist[i, j] = np.linalg.norm(key - finger)

# calculate the shortest distance
row_ind, col_ind = linear_sum_assignment(dist)
dist = dist[row_ind, col_ind]
rews = tolerance(
dist,
bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
sigmoid="gaussian",
)
return float(np.mean(rews))

def _update_goal_state(self) -> None:
# Observable callables get called after `after_step` but before
# `should_terminate_episode`. Since we increment `self._t_idx` in `after_step`,
Expand Down
Loading