Skip to content

Commit

Permalink
Remove StretchRobot import
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Sep 12, 2024
1 parent 12a2f2c commit 16dd73f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
3 changes: 2 additions & 1 deletion lerobot/common/robot_devices/robots/stretch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

@dataclass
class StretchRobotConfig:
robot_type: str | None = None
robot_type: str | None = "stretch"
cameras: dict[str, Camera] = field(default_factory=lambda: {})
# TODO(aliberts): add feature with max_relative target
# TODO(aliberts): add comment on max_relative target
Expand All @@ -43,6 +43,7 @@ def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)

self.robot_type = self.config.robot_type
self.cameras = self.config.cameras
self.is_connected = False
self.teleop = None
Expand Down
2 changes: 2 additions & 0 deletions lerobot/common/robot_devices/robots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def get_arm_id(name, arm_type):


class Robot(Protocol):
robot_type: str

def connect(self): ...
def run_calibration(self): ...
def teleop_step(self, record_data=False): ...
Expand Down
7 changes: 3 additions & 4 deletions lerobot/scripts/control_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.stretch import StretchRobot
from lerobot.common.robot_devices.robots.utils import Robot, get_arm_id
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
Expand Down Expand Up @@ -177,7 +176,7 @@ def none_or_int(value):
return int(value)


def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
Expand All @@ -197,7 +196,7 @@ def log_dt(shortname, dt_val_s):
log_dt("dt", dt_s)

# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not isinstance(robot, StretchRobot):
if not robot.robot_type.startswith("stretch"):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
Expand Down Expand Up @@ -252,7 +251,7 @@ def has_method(_object: object, method_name: str):
@safe_disconnect
def calibrate(robot: Robot, arms: list[str] | None):
# TODO(aliberts): move this code in robots' classes
if isinstance(robot, StretchRobot):
if robot.robot_type.startswith("stretch"):
if not robot.is_connected:
robot.connect()
if not robot.is_homed():
Expand Down

0 comments on commit 16dd73f

Please sign in to comment.