Skip to content

Commit

Permalink
Merge pull request #150 from NeLy-EPFL/head-stabilization-complex-ter…
Browse files Browse the repository at this point in the history
…rain

Head stabilization over complex terrain & fly-to-fly following using head stabilization
  • Loading branch information
sibocw authored May 1, 2024
2 parents 60b4ad3 + 653f091 commit d46c5d5
Show file tree
Hide file tree
Showing 18 changed files with 761 additions and 198 deletions.
2 changes: 1 addition & 1 deletion flygym/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .core import Parameters, NeuroMechFly
from .simulation import Simulation, SingleFlySimulation
from .fly import Fly
from .camera import Camera
from .camera import Camera, NeckCamera
from .util import get_data_path, load_config
37 changes: 37 additions & 0 deletions flygym/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def _initialize_custom_camera_handling(self, camera_name: str):
"Animat/camera_back",
"Animat/camera_right",
"Animat/camera_left",
"Animat/camera_neck_zoomin",
]

# always add pos update if it is a head camera
Expand Down Expand Up @@ -690,3 +691,39 @@ def _correct_camera_orientation(self, camera_name: str):
camera.pos[2] = camera.pos[2] + 1.0

return camera


class NeckCamera(Camera):
def __init__(self, **kwargs):
assert "camera_id" not in kwargs, "camera_id should not be passed to NeckCamera"
kwargs["camera_id"] = "Animat/camera_neck_zoomin"
super().__init__(**kwargs)

def _update_cam_pos(self, physics: mjcf.Physics, floor_height: float):
pass
# cam = physics.bind(self._cam)
# cam_pos = cam.xpos.copy()
# cam_pos[2] += floor_height
# cam.xpos = cam_pos

def _update_cam_rot(self, physics: mjcf.Physics):
pass
# cam = physics.bind(self._cam)

# fly_z_rot_euler = (
# np.array([self.fly.last_obs["rot"][0], 0.0, 0.0])
# - self.fly.spawn_orientation[::-1]
# - [np.pi / 2, 0, 0]
# )
# # This compensates both for the scipy to mujoco transform (align with y is
# # [0, 0, 0] in mujoco but [pi/2, 0, 0] in scipy) and the fact that the fly
# # orientation is already taken into account in the base_camera_rot (see below)
# # camera is always looking along its -z axis
# cam_matrix = R.from_euler(
# "yxz", fly_z_rot_euler
# ).as_matrix() # apply the rotation along the y axis of the cameras
# cam_matrix = self.base_camera_rot @ cam_matrix
# cam.xmat = cam_matrix.flatten()

def render(self, physics: mjcf.Physics, floor_height: float, curr_time: float):
return super().render(physics, floor_height, curr_time)
2 changes: 1 addition & 1 deletion flygym/data/mjcf/neuromechfly_seqik_kinorder_ypr.xml
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@
<!-- arctan(5/8) = 0.559 -1.57 - 0.559 = -2.129 -->
<camera name="camera_head_zoomin" class="nmf" mode="track" ipd="0.068" pos="3 -3 1" euler="1.57 0 0.72" fovy="30"/>
<camera name="camera_front_zoomin" class="nmf" mode="track" ipd="0.068" pos="8 0 1" euler="1.57 0 1.57" fovy="15"/>
<camera name="camera_neck_zoomin" class="nmf" mode="targetbodycom" ipd="0.068" pos="0.5 2 0.7" euler="-1.57 3.14 0" target="Thorax"/>
<camera name="camera_neck_zoomin" class="nmf" mode="fixed" ipd="0.068" pos="0.5 2 1.2" euler="-1.57 3.14 0"/>
</worldbody>
<actuator>
<position name="actuator_position_joint_Head_yaw" class="nmf" forcelimited="true" ctrlrange="-1000000 1000000" forcerange="-inf inf" joint="joint_Head_yaw" kp="0.9"/>
Expand Down
2 changes: 1 addition & 1 deletion flygym/examples/head_stabilization/check_videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_last_frame(video_file: Path):
if __name__ == "__main__":
base_dir = Path("./outputs/head_stabilization/random_exploration/")
last_frames = {
path.parent.name: get_last_frame(path) for (path) in base_dir.glob("*/*.mp4")
path.parent.name: get_last_frame(path) for path in base_dir.glob("*/*.mp4")
}

num_images = len(last_frames)
Expand Down
149 changes: 115 additions & 34 deletions flygym/examples/head_stabilization/closed_loop_deployment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pathlib import Path
from tqdm import trange
from flygym import Fly, Camera
from flygym import Fly, Camera, NeckCamera
from flygym.vision import Retina
from flygym.arena import BaseArena, FlatTerrain, BlocksTerrain
from typing import Optional
Expand All @@ -12,16 +12,15 @@
import flygym.examples.head_stabilization.viz as viz
from flygym.examples.vision_connectome_model import NMFRealisticVision, RetinaMapper
from flygym.examples.head_stabilization import HeadStabilizationInferenceWrapper
from flygym.examples.head_stabilization import get_head_stabilization_model_paths


contact_sensor_placements = [
f"{leg}{segment}"
for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
]
output_dir = Path("./outputs/head_stabilization/videos/")
output_dir.mkdir(exist_ok=True, parents=True)
output_dir = Path("./outputs/head_stabilization/")
(output_dir / "videos").mkdir(exist_ok=True, parents=True)

# If you trained the models yourself (by running ``collect_training_data.py``
# followed by ``train_proprioception_model.py``), you can use the following
Expand All @@ -33,13 +32,18 @@

# Alternatively, you can use the pre-trained models that come with the
# package. To do so, comment out the three lines above and uncomment the
# following line.
# following 2 lines.
# from flygym.examples.head_stabilization import get_head_stabilization_model_paths
# stabilization_model_path, scaler_param_path = get_head_stabilization_model_paths()

# Simulation parameters

run_time = 1.5


def run_simulation(
arena: BaseArena,
run_time: float = 1.0,
run_time: float = 0.5,
head_stabilization_model: Optional[HeadStabilizationInferenceWrapper] = None,
):
fly = Fly(
Expand All @@ -51,38 +55,54 @@ def run_simulation(
head_stabilization_model=head_stabilization_model,
)

cameras = [
Camera(
fly=fly,
camera_id="Animat/camera_top_zoomout",
play_speed=0.2,
window_size=(600, 600),
fps=24,
play_speed_text=False,
),
Camera(
fly=fly,
camera_id="Animat/camera_neck_zoomin",
play_speed=0.2,
window_size=(600, 600),
fps=24,
play_speed_text=False,
),
]
birdeye_camera = Camera(
fly=fly,
camera_id="Animat/camera_top_zoomout",
play_speed=0.2,
window_size=(600, 600),
fps=24,
play_speed_text=False,
)
birdeye_camera._cam.pos -= np.array([0, 0, 20.0])

neck_camera = NeckCamera(
fly=fly,
play_speed=0.2,
fps=24,
window_size=(600, 600),
camera_follows_fly_orientation=True,
play_speed_text=False,
)

sim = NMFRealisticVision(
fly=fly,
cameras=cameras,
cameras=[birdeye_camera, neck_camera],
arena=arena,
)

sim.reset(seed=0)

# These are only updated when a frame is rendered. They are used for
# generating the summary video at the end of the simulation. Each
# element in the list corresponds to a frame in the video.
birdeye_snapshots = []
zoomin_snapshots = []
raw_vision_snapshots = []
nn_activities_snapshots = []
neck_actuation_viz_vars = []

# These are updated at every time step and are used for generating
# statistics and plots (except vision_all, which is updated every
# time step where the visual input is updated. Visual updates are less
# frequent than physics steps).
head_rotation_hist = []
thorax_rotation_hist = []
neck_actuation_pred_hist = []
neck_actuation_true_hist = []
vision_all = [] # (only updated when visual input is updated)

thorax_body = fly.model.find("body", "Thorax")
head_body = fly.model.find("body", "Head")

# Main simulation loop
for i in trange(int(run_time / sim.timestep)):
Expand All @@ -95,17 +115,39 @@ def run_simulation(
# Record neck actuation for stats at the end of the simulation
if head_stabilization_model is not None:
neck_actuation_pred_hist.append(info["neck_actuation"])
quat = sim.physics.bind(fly.thorax).xquat
quat_inv = transformations.quat_inv(quat)
roll, pitch, _ = transformations.quat_to_euler(quat_inv, ordering="XYZ")
neck_actuation_true_hist.append(np.array([roll, pitch]))
quat = sim.physics.bind(fly.thorax).xquat
quat_inv = transformations.quat_inv(quat)
roll, pitch, _ = transformations.quat_to_euler(quat_inv, ordering="XYZ")
neck_actuation_true_hist.append(np.array([roll, pitch]))

# Record head and thorax orientation
thorax_rotation_quat = sim.physics.bind(thorax_body).xquat
thorax_roll, thorax_pitch, _ = transformations.quat_to_euler(
thorax_rotation_quat, ordering="XYZ"
)
thorax_rotation_hist.append([thorax_roll, thorax_pitch])
head_rotation_quat = sim.physics.bind(head_body).xquat
head_roll, head_pitch, _ = transformations.quat_to_euler(
head_rotation_quat, ordering="XYZ"
)
head_rotation_hist.append([head_roll, head_pitch])

rendered_images = sim.render()
if rendered_images[0] is not None:
birdeye_snapshots.append(rendered_images[0])
zoomin_snapshots.append(rendered_images[1])
raw_vision_snapshots.append(obs["vision"])
nn_activities_snapshots.append(info["nn_activities"])
neck_act = np.zeros(2)
if head_stabilization_model is not None:
neck_act = info["neck_actuation"]
neck_signals = np.hstack(
[np.rad2deg([roll, pitch]), np.rad2deg(neck_act), [sim.curr_time]]
)
neck_actuation_viz_vars.append(neck_signals)

if info["vision_updated"]:
vision_all.append(obs["vision"])

# Generate performance stats on head stabilization
if head_stabilization_model is not None:
Expand All @@ -121,14 +163,28 @@ def run_simulation(
}
else:
r2_scores = None
neck_actuation_true_hist = np.array(neck_actuation_true_hist)
neck_actuation_pred_hist = np.zeros_like(neck_actuation_true_hist)

# Compute standard deviation of each ommatidium's intensity
vision_all = np.array(vision_all).sum(axis=-1) # sum over both channels
vision_std = np.std(vision_all, axis=0)
vision_std_raster = fly.retina.hex_pxls_to_human_readable(vision_std.T)
vision_std_raster[fly.retina.ommatidia_id_map == 0, :] = np.nan

return {
"sim": sim,
"birdeye": birdeye_snapshots,
"zoomin": zoomin_snapshots,
"raw_vision": raw_vision_snapshots,
"nn_activities": nn_activities_snapshots,
"neck_true": neck_actuation_true_hist,
"neck_pred": neck_actuation_pred_hist,
"neck_actuation": neck_actuation_viz_vars,
"r2_scores": r2_scores,
"head_rotation_hist": np.array(head_rotation_hist),
"thorax_rotation_hist": np.array(thorax_rotation_hist),
"vision_std": vision_std_raster,
}


Expand Down Expand Up @@ -157,7 +213,7 @@ def process_trial(terrain_type: str, stabilization_on: bool, cell: str):
if terrain_type == "flat":
arena = FlatTerrain()
elif terrain_type == "blocks":
arena = BlocksTerrain(height_range=(0.2, 0.2))
arena = BlocksTerrain(height_range=(0.2, 0.2), x_range=(-5, 35))
else:
raise ValueError("Invalid terrain type")

Expand All @@ -172,7 +228,7 @@ def process_trial(terrain_type: str, stabilization_on: bool, cell: str):

# Run simulation
sim_res = run_simulation(
arena=arena, run_time=1.0, head_stabilization_model=stabilization_model
arena=arena, run_time=run_time, head_stabilization_model=stabilization_model
)
print(
f"Terrain type {terrain_type}, stabilization {stabilization_on} completed "
Expand All @@ -192,6 +248,10 @@ def process_trial(terrain_type: str, stabilization_on: bool, cell: str):
"zoomin": sim_res["zoomin"],
"raw_vision": raw_vision_hist,
"cell_response": cell_response_hist,
"head_rotation": sim_res["head_rotation_hist"],
"thorax_rotation": sim_res["thorax_rotation_hist"],
"neck_actuation": sim_res["neck_actuation"],
"vision_std": sim_res["vision_std"],
}


Expand All @@ -206,11 +266,12 @@ def process_trial(terrain_type: str, stabilization_on: bool, cell: str):
]
res_all = Parallel(n_jobs=-2)(delayed(process_trial)(*config) for config in configs)
res_all = {k[:2]: v for k, v in zip(configs, res_all)}
# res_all = {config[:2]: process_trial(*config) for config in configs}

# Make summary video
data = {}
for stabilization_on in [True, False]:
for view in ["birdeye", "zoomin", "raw_vision", "cell_response"]:
for view in ["birdeye", "zoomin", "raw_vision", "neck_actuation"]:
# Start with flat terrain
frames = res_all[("flat", stabilization_on)][view]

Expand All @@ -220,8 +281,28 @@ def process_trial(terrain_type: str, stabilization_on: bool, cell: str):

# Switch to blocks terrain
frames += res_all[("blocks", stabilization_on)][view]

data[(stabilization_on, view)] = frames
viz.closed_loop_comparison_video(
data, "T4a", 24, output_dir / "closed_loop_comparison.mp4"
data, 24, output_dir / "videos/closed_loop_comparison.mp4", run_time
)

# Plot example head and thorax rotation time series
rotation_data = {}
for terrain_type in ["flat", "blocks"]:
rotation_data[terrain_type] = {
body: res_all[(terrain_type, True)][f"{body}_rotation"]
for body in ["head", "thorax"]
}
viz.plot_rotation_time_series(
rotation_data, output_dir / "figs/rotation_time_series.pdf"
)

# Plot standard deviation of intensity per ommatidium with and
# without head stabilization
std_data = {}
for terrain_type in ["flat", "blocks"]:
for stabilization_on in [True, False]:
std_data[(terrain_type, stabilization_on)] = res_all[
(terrain_type, stabilization_on)
]["vision_std"]
viz.plot_activities_std(std_data, output_dir / "figs/vision_std.pdf")
10 changes: 5 additions & 5 deletions flygym/examples/head_stabilization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.data import Dataset
from torchmetrics.regression import R2Score
from pathlib import Path
from typing import Optional, Callable
from typing import Tuple, Optional, Callable


class JointAngleScaler:
Expand All @@ -33,7 +33,7 @@ class WalkingDataset(Dataset):
def __init__(
self,
sim_data_file: Path,
contact_force_thr: float = 3,
contact_force_thr: Tuple[float, float, float] = (0.5, 1, 3),
joint_angle_scaler: Optional[Callable] = None,
ignore_first_n: int = 200,
joint_mask=None,
Expand All @@ -45,7 +45,7 @@ def __init__(
self.terrain = terrain
self.subset = subset
self.dn_drive = f"{dn_left}_{dn_right}"
self.contact_force_thr = contact_force_thr
self.contact_force_thr = np.array([*contact_force_thr, *contact_force_thr])
self.joint_angle_scaler = joint_angle_scaler
self.ignore_first_n = ignore_first_n
self.joint_mask = joint_mask
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(
self,
model_path: Path,
scaler_param_path: Path,
contact_force_thr: float = 3.0,
contact_force_thr: Tuple[float, float, float] = (0.5, 1, 3),
):
# Load scaler params
with open(scaler_param_path, "rb") as f:
Expand All @@ -163,7 +163,7 @@ def __init__(
self.model = ThreeLayerMLP.load_from_checkpoint(
model_path, map_location=torch.device("cpu")
)
self.contact_force_thr = contact_force_thr
self.contact_force_thr = np.array([*contact_force_thr, *contact_force_thr])

def __call__(
self, joint_angles: np.ndarray, contact_forces: np.ndarray
Expand Down
Loading

0 comments on commit d46c5d5

Please sign in to comment.