Skip to content

Commit

Permalink
Add capture_observation
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Sep 12, 2024
1 parent 16dd73f commit 88526d0
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions lerobot/common/robot_devices/robots/stretch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,18 @@ def run_calibration(self):
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
# TODO(aliberts): return proper types (ndarrays instead of torch.Tensors)
# TODO(aliberts): return ndarrays instead of torch.Tensors
if self.teleop is None:
self.teleop = GamePadTeleop(robot_instance=False)
self.teleop.startup(robot=self)

before_read_t = time.perf_counter()
self.teleop.do_motion(robot=self)
state = self._get_state()
action = self.teleop.gamepad_controller.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

before_write_t = time.perf_counter()
self.teleop.do_motion(robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

Expand Down Expand Up @@ -126,7 +126,33 @@ def _get_state(self) -> dict:
"base_theta.vel": status["base"]["theta_vel"],
}

def capture_observation(self): ...
def capture_observation(self):
# TODO(aliberts): return ndarrays instead of torch.Tensors
before_read_t = time.perf_counter()
state = self._get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

if self.state_keys is None:
self.state_keys = list(state)

state = torch.as_tensor(list(state.values()))

# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t

# Populate output dictionnaries
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]

return obs_dict

def send_action(self, action): ...

Expand Down

0 comments on commit 88526d0

Please sign in to comment.