diff --git a/README.md b/README.md new file mode 100644 index 0000000..4cc91c2 --- /dev/null +++ b/README.md @@ -0,0 +1,90 @@ +# ACT: Action Chunking with Transformers + +### Project Website: https://tonyzhaozh.github.io/aloha/ + +This repo contains the implementation of ACT, together with 2 simulated environments: +Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim (tested) or real (ongoing). + + +### Repo Structure +- ``imitate_episodes.py`` Train and Evaluate ACT +- ``policy.py`` An adaptor for ACT policy +- ``detr`` Model definitions of ACT, modified from DETR +- ``sim_env.py`` Mujoco + DM_Control environments with joint space control +- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control +- ``scripted_policy.py`` Scripted policies for sim environments +- ``constants.py`` Constants shared across files +- ``utils.py`` Utils such as data loading and helper functions +- ``visualize_episodes.py`` Save videos from a .hdf5 dataset + + +### Installation + + conda create -n aloha python=3.8 + conda activate aloha + pip install torchvision + pip install torch + pip install pyquaternion + pip install pyyaml + pip install rospkg + pip install pexpect + pip install mujoco + pip install dm_control + pip install opencv-python + pip install matplotlib + pip install einops + pip install packaging + pip install h5py + pip install h5py_cache + cd act/detr && pip install -e . + +### Example Usages + +To set up a new terminal, run: + + conda activate aloha + cd + +### Simulated experiments + +We use ``transfer_cube`` task in the examples below. Another option is ``insertion``. +To generated 50 episodes of scripted data, run: + + python3 record_sim_episodes.py \ + --task_name transfer_cube \ + --dataset_dir \ + --num_episodes 50 + +To can add the flag ``--onscreen_render`` to see real-time rendering. +To visualize the episode after it is collected, run + + python3 visualize_episodes.py --dataset_dir --episode_idx 0 + +To train ACT: + + # Transfer Cube task + python3 imitate_episodes.py \ + --dataset_dir \ + --ckpt_dir \ + --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ + --task_name transfer_cube --seed 0 \ + --temporal_agg \ + --num_epochs 1000 --lr 1e-4 + + # Bimanual Insertion task + python3 imitate_episodes.py \ + --dataset_dir \ + --ckpt_dir \ + --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ + --task_name insertion --seed 0 \ + --temporal_agg \ + --num_epochs 2000 --lr 1e-5 + +To evaluate the policy, run the same command but add ``--eval``. The success rate +should be around 85% for transfer cube, and around 50% for insertion. +Videos will be saved to ```` for each rollout. +You can also add ``--onscreen_render`` to see real-time rendering during evaluation. + + + + diff --git a/assets/bimanual_viperx_ee_insertion.xml b/assets/bimanual_viperx_ee_insertion.xml new file mode 100644 index 0000000..700aaac --- /dev/null +++ b/assets/bimanual_viperx_ee_insertion.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/bimanual_viperx_ee_transfer_cube.xml b/assets/bimanual_viperx_ee_transfer_cube.xml new file mode 100644 index 0000000..2589384 --- /dev/null +++ b/assets/bimanual_viperx_ee_transfer_cube.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/bimanual_viperx_insertion.xml b/assets/bimanual_viperx_insertion.xml new file mode 100644 index 0000000..f701d70 --- /dev/null +++ b/assets/bimanual_viperx_insertion.xml @@ -0,0 +1,53 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/bimanual_viperx_transfer_cube.xml b/assets/bimanual_viperx_transfer_cube.xml new file mode 100644 index 0000000..bdc9e64 --- /dev/null +++ b/assets/bimanual_viperx_transfer_cube.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/scene.xml b/assets/scene.xml new file mode 100644 index 0000000..5f596bf --- /dev/null +++ b/assets/scene.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/tabletop.stl b/assets/tabletop.stl new file mode 100644 index 0000000..ab35cdf Binary files /dev/null and b/assets/tabletop.stl differ diff --git a/assets/vx300s_10_custom_finger_left.stl b/assets/vx300s_10_custom_finger_left.stl new file mode 100644 index 0000000..534c7af Binary files /dev/null and b/assets/vx300s_10_custom_finger_left.stl differ diff --git a/assets/vx300s_10_custom_finger_right.stl b/assets/vx300s_10_custom_finger_right.stl new file mode 100644 index 0000000..d6a492c Binary files /dev/null and b/assets/vx300s_10_custom_finger_right.stl differ diff --git a/assets/vx300s_10_gripper_finger.stl b/assets/vx300s_10_gripper_finger.stl new file mode 100644 index 0000000..d6df86b Binary files /dev/null and b/assets/vx300s_10_gripper_finger.stl differ diff --git a/assets/vx300s_11_ar_tag.stl b/assets/vx300s_11_ar_tag.stl new file mode 100644 index 0000000..193014b Binary files /dev/null and b/assets/vx300s_11_ar_tag.stl differ diff --git a/assets/vx300s_1_base.stl b/assets/vx300s_1_base.stl new file mode 100644 index 0000000..5a7efda Binary files /dev/null and b/assets/vx300s_1_base.stl differ diff --git a/assets/vx300s_2_shoulder.stl b/assets/vx300s_2_shoulder.stl new file mode 100644 index 0000000..dc22aa7 Binary files /dev/null and b/assets/vx300s_2_shoulder.stl differ diff --git a/assets/vx300s_3_upper_arm.stl b/assets/vx300s_3_upper_arm.stl new file mode 100644 index 0000000..111c586 Binary files /dev/null and b/assets/vx300s_3_upper_arm.stl differ diff --git a/assets/vx300s_4_upper_forearm.stl b/assets/vx300s_4_upper_forearm.stl new file mode 100644 index 0000000..8170d21 Binary files /dev/null and b/assets/vx300s_4_upper_forearm.stl differ diff --git a/assets/vx300s_5_lower_forearm.stl b/assets/vx300s_5_lower_forearm.stl new file mode 100644 index 0000000..39581f8 Binary files /dev/null and b/assets/vx300s_5_lower_forearm.stl differ diff --git a/assets/vx300s_6_wrist.stl b/assets/vx300s_6_wrist.stl new file mode 100644 index 0000000..ab8423e Binary files /dev/null and b/assets/vx300s_6_wrist.stl differ diff --git a/assets/vx300s_7_gripper.stl b/assets/vx300s_7_gripper.stl new file mode 100644 index 0000000..043db9c Binary files /dev/null and b/assets/vx300s_7_gripper.stl differ diff --git a/assets/vx300s_8_gripper_prop.stl b/assets/vx300s_8_gripper_prop.stl new file mode 100644 index 0000000..36099b4 Binary files /dev/null and b/assets/vx300s_8_gripper_prop.stl differ diff --git a/assets/vx300s_9_gripper_bar.stl b/assets/vx300s_9_gripper_bar.stl new file mode 100644 index 0000000..eba3caa Binary files /dev/null and b/assets/vx300s_9_gripper_bar.stl differ diff --git a/assets/vx300s_dependencies.xml b/assets/vx300s_dependencies.xml new file mode 100644 index 0000000..c75d3ad --- /dev/null +++ b/assets/vx300s_dependencies.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/vx300s_left.xml b/assets/vx300s_left.xml new file mode 100644 index 0000000..61e6219 --- /dev/null +++ b/assets/vx300s_left.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/vx300s_right.xml b/assets/vx300s_right.xml new file mode 100644 index 0000000..2c6f007 --- /dev/null +++ b/assets/vx300s_right.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/commands.txt b/commands.txt new file mode 100644 index 0000000..b95b803 --- /dev/null +++ b/commands.txt @@ -0,0 +1,1038 @@ + + + +# ROS terminal +ros-init +roslaunch aloha 4arms_teleop.launch + +# Right hand terminal +conda activate aloha +cd-ps +python3 one_side_teleop.py right + +# Left hand terminal +conda activate aloha +cd-ps +python3 one_side_teleop.py left + +# Sleep terminal +conda activate aloha +cd-ps +python3 sleep.py + +# To recompile +ros-init +catkin_make + + + +python3 record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/test_new --episode_idx 0 + +python3 visualize_episodes.py --dataset_dir /scr2/tonyzhao/datasets/test_new --episode_idx 0 + +python3 replay_episodes.py --dataset_dir /scr2/tonyzhao/datasets/test_new --episode_idx 0 + +python3 imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_25_battery \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_25_battery-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 90 --hidden_dim 512 --batch_size 4 --dim_feedforward 3000 \ +--task_name battery --seed 0 --eval + +##################################################################### +##################################################################### +##################################################################### + +# record sim episodes + +python3 record_sim_episodes.py --task_name transfer_cube --dataset_dir /scr2/tonyzhao/datasets/test_transfer --num_episodes 50 +python3 record_sim_episodes.py --task_name insertion --dataset_dir /scr2/tonyzhao/datasets/test_insertion --num_episodes 50 + + +python3 visualize_episodes.py --dataset_dir /scr2/tonyzhao/datasets/test_transfer --episode_idx 0 + + +python3 imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr2/tonyzhao/train_logs/test_transfer \ +--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 \ +--num_epochs 1000 --lr 1e-4 +# GOOD + +python3 imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_9_sim_insertion \ +--ckpt_dir /scr2/tonyzhao/train_logs/test_insertion \ +--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name insertion --seed 0 \ +--num_epochs 1000 --lr 1e-4 + + +python3 imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_9_sim_insertion \ +--ckpt_dir /scr2/tonyzhao/train_logs/test_insertion2 \ +--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name insertion --seed 0 \ +--num_epochs 2000 --lr 1e-4 + + +python3 imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_9_sim_insertion \ +--ckpt_dir /scr2/tonyzhao/train_logs/test_insertion3 \ +--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name insertion --seed 0 \ +--num_epochs 2000 --lr 1e-5 +# 48% +# 54% with TA + + +python3 imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr2/tonyzhao/train_logs/test_transfer3 \ +--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 \ +--num_epochs 2000 --lr 1e-5 + + + +# MLP + +python3 imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr2/tonyzhao/train_logs/test_transfer-mlp \ +--policy_class CNNMLP --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 \ +--num_epochs 1000 --lr 1e-4 +# GOOD + + + +##################################################################### +##################################################################### +##################################################################### + + +# training bc +conda activate fm +cd-ps +python3 examples/imitate_episodes.py --ckpt_dir /home/tonyzhao/train_logs/10_18_ovefit_GPT/ --eval --onscreen_render + + +# sim and collision detection + +cd /home/tonyzhao/mujoco-2.2.1/bin +./simulate /home/tonyzhao/interbotix_ws/src/interbotix_ros_manipulators/interbotix_ros_xsarms/interbotix_xsarm_descriptions/urdf/bimanual_viperx.xml + + +# run experiments in cluster + +export MUJOCO_GL=osmesa +cd /afs/cs.stanford.edu/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples +python3 imitate_episodes.py --ckpt_dir=/iris/u/tonyzhao/train_logs/11_1_detr_cluster + + + +# WIP + +# fix usb port for robots +udevadm info --name=/dev/ttyUSB0 --attribute-walk | grep serial + +SUBSYSTEM=="tty", ATTRS{serial}=="FT6S4DSP", ATTRS{idVendor}=="0403", ATTRS{idProduct}=="6014", ENV{ID_MM_DEVICE_IGNORE}="1", ATTR{device/latency_timer}="1", SYMLINK+="ttyDXL0" +SUBSYSTEM=="tty", ATTRS{serial}=="FT6S4HW3", ATTRS{idVendor}=="0403", ATTRS{idProduct}=="6014", ENV{ID_MM_DEVICE_IGNORE}="1", ATTR{device/latency_timer}="1", SYMLINK+="ttyDXL1" +SUBSYSTEM=="tty", ATTRS{serial}=="FT4NQ4YH", ATTRS{idVendor}=="0403", ATTRS{idProduct}=="6014", ENV{ID_MM_DEVICE_IGNORE}="1", ATTR{device/latency_timer}="1", SYMLINK+="ttyDXL2" +SUBSYSTEM=="tty", ATTRS{serial}=="FT6S4DOU", ATTRS{idVendor}=="0403", ATTRS{idProduct}=="6014", ENV{ID_MM_DEVICE_IGNORE}="1", ATTR{device/latency_timer}="1", SYMLINK+="ttyDXL3" + +# fix usb port for cameras +udevadm info --name=/dev/video0 --attribute-walk | grep serial + +SUBSYSTEM=="video4linux", ATTRS{serial}=="C58A5FAF", ATTR{index}=="0", ATTRS{idProduct}=="085c", ATTR{device/latency_timer}="1", SYMLINK+="CAM_RIGHT_WRIST" +SUBSYSTEM=="video4linux", ATTRS{serial}=="7FDB4B6F", ATTR{index}=="0", ATTRS{idProduct}=="085c", ATTR{device/latency_timer}="1", SYMLINK+="CAM_HIGH" +SUBSYSTEM=="video4linux", ATTRS{serial}=="0E1A2B2F", ATTR{index}=="0", ATTRS{idProduct}=="085c", ATTR{device/latency_timer}="1", SYMLINK+="CAM_LEFT_WRIST" +SUBSYSTEM=="video4linux", ATTRS{serial}=="98ED30BF", ATTR{index}=="0", ATTRS{idProduct}=="085c", ATTR{device/latency_timer}="1", SYMLINK+="CAM_LOW" + +# add these commands and reload usb connections +sudo vim /etc/udev/rules.d/99-fixed-interbotix-udev.rules +sudo udevadm control --reload +sudo udevadm trigger + + +# Installation + +Install interbotix and ROS. Ubuntu 18.01 or 20.01 +https://www.trossenrobotics.com/docs/interbotix_xsarms/ros_interface/software_setup.html + +Install oculus reader +https://github.com/rail-berkeley/oculus_reader +After following the first few installations +Clone the repo +pip install -e + +Follow instruction, and use apt get to install ros-noetic-* + + +# installing usb camera +sudo apt-get install ros-noetic-usb-cam +sudo apt-get install ros-noetic-cv-bridge + + +conda activate fm +pip install torchvision +pip install torch +pip install pyquaternion +pip install pyyaml +pip install rospkg +pip install pexpect +pip install mujoco +pip install dm_control +pip install opencv-python +pip install matplotlib +pip install einops +pip install packaging +pip install h5py + +### Experiments + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_10_velcro_thread_win_30_kl_5 --kl_weight 5 --window_len 30 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_10_velcro_thread_win_30_kl_10 --kl_weight 10 --window_len 30 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_10_velcro_thread_win_30_kl_20 --kl_weight 20 --window_len 30 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_10_velcro_thread_win_30_kl_40 --kl_weight 40 --window_len 30 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_10_velcro_thread_win_15_kl_10 --kl_weight 10 --window_len 15 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_11_velcro_thread_win_60_kl_10 --kl_weight 10 --window_len 60 + + +Dec 21 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_21_velcro_thread_win_60_kl_10 --kl_weight 10 --window_len 60 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_21_velcro_thread_win_30_kl_10 --kl_weight 10 --window_len 30 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_21_velcro_thread_win_60_kl_20 --kl_weight 20 --window_len 60 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_21_velcro_thread_win_30_kl_20 --kl_weight 20 --window_len 30 --eval + +Dec 22 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_22_velcro_thread_win_60_kl_10 --kl_weight 10 --window_len 60 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_22_velcro_thread_win_60_kl_20 --kl_weight 20 --window_len 60 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_22_velcro_thread_win_60_kl_40 --kl_weight 40 --window_len 60 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_22_velcro_thread_win_30_kl_10 --kl_weight 10 --window_len 30 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_22_velcro_thread_win_30_kl_40 --kl_weight 40 --window_len 30 --eval + +round 2 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/12_22_velcro_thread_win_60_kl_10 --kl_weight 5 --window_len 60 + +# human data + detrvae + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_try1 --kl_weight 10 --window_len 60 +Success rate: 0.14 +Average return: 22.04 + + +try 512 hidden_dim +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_try1_expressive --kl_weight 10 --window_len 60 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_try1_expressive2 --kl_weight 10 --window_len 50 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_try2 --kl_weight 10 --window_len 400 + + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_try3_expressive --kl_weight 10 --window_len 200 --eval + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_try4_expressive --kl_weight 10 --window_len 100 --eval + + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_try1_20kl --kl_weight 20 --window_len 60 # --hidden_dim 256 --batch_size 4 +# 0.32 + + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try1 --kl_weight 20 --window_len 60 --hidden_dim 256 --batch_size 32 +Success rate: 0.02 +Average return: 2.32 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try2 --kl_weight 30 --window_len 60 --hidden_dim 256 --batch_size 32 +Success rate: 0.02 +Average return: 2.22 + + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try3 --kl_weight 20 --window_len 60 --hidden_dim 512 --batch_size 16 +Success rate: 0.18 +Average return: 22.8 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try4 --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 16 +Success rate: 0.38 +Average return: 33.16 + + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try5 --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 +Success rate: 0.04 +Average return: 4.1 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try6 --kl_weight 40 --window_len 60 --hidden_dim 512 --batch_size 16 +Success rate: 0.26 +Average return: 33.18 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try7 --kl_weight 30 --window_len 60 --hidden_dim 1024 --batch_size 8 +Success rate: 0.04 +Average return: 5.9 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_1_sim_box_human_detrvae_sweep_try8 --kl_weight 40 --window_len 60 --hidden_dim 1024 --batch_size 8 +Success rate: 0.02 +Average return: 0.42 + + +# new task insertion + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_10_sim_insertion_detrvae_sweep_try1 --kl_weight 10 --window_len 400 --hidden_dim 256 --batch_size 8 +Success rate: 0.0 +Average return: 0.0 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_10_sim_insertion_detrvae_sweep_try2 --kl_weight 10 --window_len 200 --hidden_dim 256 --batch_size 8 +Success rate: 0.15 +Average return: 1.95 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_10_sim_insertion_detrvae_sweep_try3 --kl_weight 10 --window_len 100 --hidden_dim 256 --batch_size 8 +Success rate: 0.4 +Average return: 16.7 + +python3 examples/imitate_episodes.py --ckpt_dir /scr/tonyzhao/train_logs/1_10_sim_insertion_detrvae_sweep_try4 --kl_weight 10 --window_len 50 --hidden_dim 256 --batch_size 8 +Success rate: 0.2 +Average return: 9.8 + + + +### Ours - experiments and time ensemble + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 16 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 +# 28 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-2 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 +# 38 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-3 \ +--policy_class DETRVAE --kl_weight 30 --window_len 120 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 +# BAD + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-4 \ +--policy_class DETRVAE --kl_weight 30 --window_len 30 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 +# BAD + +# maybe just train more? +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-5 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 +# 30 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-6 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 1024 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 +# bad + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-7 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 6144 \ +--task_name transfer_cube --seed 0 +# bad + + +######### + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-Tweak-seed-0 \ +--policy_class DETRVAE --kl_weight 40 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 --eval + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-Tweak-seed-1 \ +--policy_class DETRVAE --kl_weight 40 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 1 --eval + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-Tweak-seed-2 \ +--policy_class DETRVAE --kl_weight 40 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 2 --eval + +### MAJOR REGRESSION + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win400-seed-0-fix-regression \ +--policy_class DETRVAE --kl_weight 10 --window_len 400 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 + + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win400-seed-1-fix-regression \ +--policy_class DETRVAE --kl_weight 10 --window_len 400 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 1 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win400-seed-2-fix-regression \ +--policy_class DETRVAE --kl_weight 10 --window_len 400 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 2 + +### TODO TODO try with and without time ensemble! + + +### Retry with 200 window + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-1 \ +--policy_class DETRVAE --kl_weight 10 --window_len 200 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 --eval + +Reward >= 0: 50/50 = 100.0% +Reward >= 1: 42/50 = 84.0% +Reward >= 2: 40/50 = 80.0% +Reward >= 3: 36/50 = 72.0% +Reward >= 4: 36/50 = 72.0% + +After time ensemble +Reward >= 0: 50/50 = 100.0% +Reward >= 1: 48/50 = 96.0% +Reward >= 2: 46/50 = 92.0% +Reward >= 3: 45/50 = 90.0% +Reward >= 4: 45/50 = 90.0% + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-1-seed-1 \ +--policy_class DETRVAE --kl_weight 10 --window_len 200 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 1 +Reward >= 0: 50/50 = 100.0% +Reward >= 1: 38/50 = 76.0% +Reward >= 2: 34/50 = 68.0% +Reward >= 3: 14/50 = 28.000000000000004% +Reward >= 4: 14/50 = 28.000000000000004% + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-1-seed-2 \ +--policy_class DETRVAE --kl_weight 10 --window_len 200 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 2 +Reward >= 0: 50/50 = 100.0% +Reward >= 1: 17/50 = 34.0% +Reward >= 2: 9/50 = 18.0% +Reward >= 3: 2/50 = 4.0% +Reward >= 4: 2/50 = 4.0% + +### + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-3-seed-0 \ +--policy_class DETRVAE --kl_weight 10 --window_len 50 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 +Reward >= 0: 50/50 = 100.0% +Reward >= 1: 41/50 = 82.0% +Reward >= 2: 37/50 = 74.0% +Reward >= 3: 23/50 = 46.0% +Reward >= 4: 23/50 = 46.0% + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-3-seed-1 \ +--policy_class DETRVAE --kl_weight 10 --window_len 50 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 1 +Reward >= 0: 50/50 = 100.0% +Reward >= 1: 32/50 = 64.0% +Reward >= 2: 13/50 = 26.0% +Reward >= 3: 1/50 = 2.0% +Reward >= 4: 1/50 = 2.0% + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-3-seed-2 \ +--policy_class DETRVAE --kl_weight 10 --window_len 50 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 2 +Reward >= 0: 50/50 = 100.0% +Reward >= 1: 45/50 = 90.0% +Reward >= 2: 41/50 = 82.0% +Reward >= 3: 21/50 = 42.0% +Reward >= 4: 21/50 = 42.0% + + +# others +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win400-seed-0-1 \ +--policy_class DETRVAE --kl_weight 10 --window_len 400 --hidden_dim 512 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win400-seed-1 \ +--policy_class DETRVAE --kl_weight 30 --window_len 400 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 1 + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win400-seed-2 \ +--policy_class DETRVAE --kl_weight 30 --window_len 400 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 2 + + + + +# TODO compare with no VAE + +#################################################################################################### +###################################### Paper Experiments ########################################### +#################################################################################################### + +### train 3 seeds of BYOL on scripted and human data + +srun --account=iris -p iris-interactive --mem=20GB --gres=gpu:1 --pty bash + +rsync -ra /scr/tonyzhao/datasets/ tonyzhao@scdt:/iris/u/tonyzhao/datasets + +cd /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning +conda activate tonyz-fm + +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/12_25_sim_transfer_box_top_view --seed 0 # DONE +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/12_25_sim_transfer_box_top_view --seed 1 # DONE +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/12_25_sim_transfer_box_top_view --seed 2 # DONE + +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/12_30_sim_box_50_human --seed 0 # DONE +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/12_30_sim_box_50_human --seed 1 # DONE +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/12_30_sim_box_50_human --seed 2 # DONE + +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/1_9_sim_insertion --seed 0 +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/1_9_sim_insertion --seed 1 +python3 train.py --dataset_dir /iris/u/tonyzhao/datasets/1_9_sim_insertion --seed 2 + + +# sync the checkpoints back +rsync -ra tonyzhao@scdt:/iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/*.pt /scr/tonyzhao/remote_trained + + +### Cache feature for all models +# LOCAL +launch +python3 examples/cache_feature.py --ckpt_path /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-0.pt # DONE +python3 examples/cache_feature.py --ckpt_path /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-1.pt # DONE +python3 examples/cache_feature.py --ckpt_path /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-2.pt # DONE + +python3 examples/cache_feature.py --ckpt_path /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-0.pt # DONE +python3 examples/cache_feature.py --ckpt_path /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-1.pt # DONE +python3 examples/cache_feature.py --ckpt_path /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-2.pt # DONE + + +# REMOTE +cd /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts +conda activate tonyz-fm + +# DONE +python3 examples/cache_feature.py --ckpt_path /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/byol-12_25_sim_transfer_box_top_view-seed-0.pt +python3 examples/cache_feature.py --ckpt_path /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/byol-12_25_sim_transfer_box_top_view-seed-1.pt +python3 examples/cache_feature.py --ckpt_path /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/byol-12_25_sim_transfer_box_top_view-seed-2.pt + +# DONE +python3 examples/cache_feature.py --ckpt_path /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/byol-12_30_sim_box_50_human-seed-0.pt +python3 examples/cache_feature.py --ckpt_path /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/byol-12_30_sim_box_50_human-seed-1.pt +python3 examples/cache_feature.py --ckpt_path /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/byol-12_30_sim_box_50_human-seed-2.pt + + +### VINN + +# Scripted data +# select K +python3 examples/vinn_select_k.py # hardcode parameters + +# LOCAL +mkdir /scr/tonyzhao/train_logs/vinn_12_25_sim_transfer_box_top_view-seed-0 +python3 examples/vinn_eval.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-0.pt \ +--task_name transfer_cube \ +--ckpt_dir /scr/tonyzhao/train_logs/vinn_12_25_sim_transfer_box_top_view-seed-0 +# RECORDED + +mkdir /scr/tonyzhao/train_logs/vinn_12_25_sim_transfer_box_top_view-seed-1 +python3 examples/vinn_eval.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-1.pt \ +--task_name transfer_cube \ +--ckpt_dir /scr/tonyzhao/train_logs/vinn_12_25_sim_transfer_box_top_view-seed-1 +# RECORDED + +mkdir /scr/tonyzhao/train_logs/vinn_12_25_sim_transfer_box_top_view-seed-2 +python3 examples/vinn_eval.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-2.pt \ +--task_name transfer_cube \ +--ckpt_dir /scr/tonyzhao/train_logs/vinn_12_25_sim_transfer_box_top_view-seed-2 +# RECORDED + +# Human data +# select K +python3 examples/vinn_select_k.py # hardcode parameters + +# LOCAL +mkdir /scr/tonyzhao/train_logs/vinn-12_30_sim_box_50_human-seed-0 +python3 examples/vinn_eval.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-0.pt \ +--task_name transfer_cube \ +--ckpt_dir /scr/tonyzhao/train_logs/vinn-12_30_sim_box_50_human-seed-0 +# RECORDED + +mkdir /scr/tonyzhao/train_logs/vinn-12_30_sim_box_50_human-seed-1 +python3 examples/vinn_eval.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-1.pt \ +--task_name transfer_cube \ +--ckpt_dir /scr/tonyzhao/train_logs/vinn-12_30_sim_box_50_human-seed-1 +# RECORDED + +mkdir /scr/tonyzhao/train_logs/vinn-12_30_sim_box_50_human-seed-2 +python3 examples/vinn_eval.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-2.pt \ +--task_name transfer_cube \ +--ckpt_dir /scr/tonyzhao/train_logs/vinn-12_30_sim_box_50_human-seed-2 +# RECORDED + + + +### BET + +# LOCAL + +# Scripted data +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-0.pt \ +--ckpt_dir /scr/tonyzhao/train_logs/bet-12_25_sim_transfer_box_top_view-seed-0 \ +--window_len 100 --n_embd 1500 --state_repeat 10 --seed 0 +# RECORDED + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-1.pt \ +--ckpt_dir /scr/tonyzhao/train_logs/bet-12_25_sim_transfer_box_top_view-seed-1 \ +--window_len 100 --n_embd 1500 --state_repeat 10 --seed 1 +# RECORDED + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_25_sim_transfer_box_top_view-seed-2.pt \ +--ckpt_dir /scr/tonyzhao/train_logs/bet-12_25_sim_transfer_box_top_view-seed-2 \ +--window_len 100 --n_embd 1500 --state_repeat 10 --seed 2 +# RECORDED + +# Human data +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-0.pt \ +--ckpt_dir /scr/tonyzhao/train_logs/bet-12_30_sim_box_50_human-seed-0 \ +--window_len 100 --n_embd 1500 --state_repeat 10 --seed 0 +# RECORDED + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-1.pt \ +--ckpt_dir /scr/tonyzhao/train_logs/bet-12_30_sim_box_50_human-seed-1 \ +--window_len 100 --n_embd 1500 --state_repeat 10 --seed 1 +# RECORDED + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--model_dir /scr/tonyzhao/remote_trained/byol-12_30_sim_box_50_human-seed-2.pt \ +--ckpt_dir /scr/tonyzhao/train_logs/bet-12_30_sim_box_50_human-seed-2 \ +--window_len 100 --n_embd 1500 --state_repeat 10 --seed 2 +# RECORDED + +# REMOTE # not used +python3 examples/imitate_episodes.py \ +--dataset_dir /iris/u/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--model_dir /iris/u/tonyzhao/Research/interbotix-src/interbotix_ros_manipulators/interbotix_ros_xsarms/examples/teleoperation/python_scripts/examples/byol_pytorch/examples/lightning/byol-12_25_sim_transfer_box_top_view-seed-0.pt +--ckpt_dir /iris/u/tonyzhao/train_logs/1_10_sim_box_bet_vision_byol_try2 \ +--window_len 100 --n_embd 1500 --state_repeat 10 --seed 0 + + +### MLP + +# scripted data +# LOCAL +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/CNNMLP_12_25_sim_transfer_box_top_view-seed-0 \ +--policy_class CNNMLP --batch_size 16 --seed 0 --task_name transfer_cube + +# Remote +python3 examples/imitate_episodes.py \ +--dataset_dir /iris/u/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /iris/u/tonyzhao/train_logs/CNNMLP_12_25_sim_transfer_box_top_view-seed-0 \ +--policy_class CNNMLP --batch_size 16 --seed 0 --task_name transfer_cube + +python3 examples/imitate_episodes.py \ +--dataset_dir /iris/u/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /iris/u/tonyzhao/train_logs/CNNMLP_12_25_sim_transfer_box_top_view-seed-1 \ +--policy_class CNNMLP --batch_size 16 --seed 1 --task_name transfer_cube + +python3 examples/imitate_episodes.py \ +--dataset_dir /iris/u/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /iris/u/tonyzhao/train_logs/CNNMLP_12_25_sim_transfer_box_top_view-seed-2 \ +--policy_class CNNMLP --batch_size 16 --seed 2 --task_name transfer_cube + +# human data + +python3 examples/imitate_episodes.py \ +--dataset_dir /iris/u/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /iris/u/tonyzhao/train_logs/CNNMLP_12_30_sim_box_50_human-seed-0 \ +--policy_class CNNMLP --batch_size 16 --seed 0 --task_name transfer_cube + +python3 examples/imitate_episodes.py \ +--dataset_dir /iris/u/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /iris/u/tonyzhao/train_logs/CNNMLP_12_30_sim_box_50_human-seed-1 \ +--policy_class CNNMLP --batch_size 16 --seed 1 --task_name transfer_cube + +python3 examples/imitate_episodes.py \ +--dataset_dir /iris/u/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /iris/u/tonyzhao/train_logs/CNNMLP_12_30_sim_box_50_human-seed-2 \ +--policy_class CNNMLP --batch_size 16 --seed 2 --task_name transfer_cube + + +### DETR VAE with human data + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-2 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 +# Recorded + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-2-seed-1 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 1 +# Recorded + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_30_sim_box_50_human \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_30_sim_box_50_human-TEST-2-seed-2 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 2 +# Recorded + + +### DETR VAE with scripted data + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win60-seed-0 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 0 --eval +# Recorded # TODO also run time ensemble version? + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win60-seed-1 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 1 +# Recorded # TODO also run time ensemble version? + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_win60-seed-2 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 8 --dim_feedforward 4096 \ +--task_name transfer_cube --seed 2 +# Recorded, with time ensemble + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-2-seed-0 \ +--policy_class DETRVAE --kl_weight 10 --window_len 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 0 +# Recorded + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-2-seed-1 \ +--policy_class DETRVAE --kl_weight 10 --window_len 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 1 +# Recorded + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr/tonyzhao/datasets/12_25_sim_transfer_box_top_view \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_12_25_sim_transfer_box_top_view_sweep-2-seed-2 \ +--policy_class DETRVAE --kl_weight 10 --window_len 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \ +--task_name transfer_cube --seed 2 +# Recorded + + + +############################################################### +### Real robot experiments +############################################################### + +### Ziploc slide + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_20_zip_slide --episode_idx 0 + +# Ours +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_20_zip_slide \ +--ckpt_dir /scr/tonyzhao/train_logs/DETRVAE_1_20_zip_slide-seed-0 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 2 --dim_feedforward 4096 \ +--task_name ziploc_slide --seed 0 +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_20_zip_slide \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_20_zip_slide-seed-0-1 \ +--policy_class DETRVAE --kl_weight 20 --window_len 60 --hidden_dim 256 --batch_size 4 --dim_feedforward 2048 \ +--task_name ziploc_slide --seed 0 +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_20_zip_slide \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_20_zip_slide-seed-0-2 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 256 --batch_size 4 --dim_feedforward 2048 \ +--task_name ziploc_slide --seed 0 +# BEST + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_20_zip_slide \ +--ckpt_dir /scr2/tonyzhao/train_logs/CNNMLP_1_20_zip_slide-seed-0 \ +--policy_class CNNMLP --batch_size 4 \ +--task_name ziploc_slide --seed 0 + + +### Ziploc open + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_21_zipoc-seed-0 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 256 --batch_size 4 --dim_feedforward 2048 \ +--task_name ziploc --seed 0 + +# overnight + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_21_zipoc-seed-0-1 \ +--policy_class DETRVAE --kl_weight 30 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name ziploc --seed 0 --eval + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_21_zipoc-seed-0-2 \ +--policy_class DETRVAE --kl_weight 50 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name ziploc --seed 0 --eval + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_21_zipoc-seed-0-3 \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name ziploc --seed 0 --eval + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_21_zipoc-seed-0-4 \ +--policy_class DETRVAE --kl_weight 100 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name ziploc --seed 0 --eval + +### 1 conv only + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_22_zipoc-seed-0-1conv \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name ziploc --seed 0 + +# just train for longer + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_22_zipoc-seed-0-1conv_longer \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name ziploc --seed 0 + + +rsync -ra tonyzhao@scdt:/iris/u/tonyzhao/train_logs/DETRVAE_1_22_zipoc-seed-0-try1 /scr/tonyzhao/remote_trained/DETRVAE_1_22_zipoc-seed-0-try1 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_21_zipoc \ +--ckpt_dir /scr/tonyzhao/remote_trained/DETRVAE_1_22_zipoc-seed-0-try1 \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 1024 --batch_size 8 --dim_feedforward 4096 \ +--task_name ziploc --seed 0 + + + +Do we need the VAE? + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_22_zipoc_fixed --episode_idx 0 + + +### Condiment cups + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_22_cups_open --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_22_cups_open \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_22_cups_open-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name cup_open --seed 0 --eval + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_22_cups_open \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_22_cups_open-seed-0-1 \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name cup_open --seed 0 --eval + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_22_cups_open \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_22_cups_open-seed-0-2 \ +--policy_class DETRVAE --kl_weight 100 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name cup_open --seed 0 --eval + +### Condiment cups2 + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_23_cups_open --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_23_cups_open \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_23_cups_open-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name cup_open --seed 0 + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_23_cups_open \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_23_cups_open-seed-0-longwindow \ +--policy_class DETRVAE --kl_weight 80 --window_len 90 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name cup_open --seed 0 + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_23_cups_open \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_23_cups_open-seed-0-longer \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name cup_open --seed 0 + + + + +### Battery slotting + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_23_battery --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_23_battery \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_23_battery-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 60 --hidden_dim 512 --batch_size 4 --dim_feedforward 3200 \ +--task_name battery --seed 0 + +### Battery slotting try 2 + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_24_battery --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_24_battery \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_24_battery-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 90 --hidden_dim 512 --batch_size 4 --dim_feedforward 3000 \ +--task_name battery --seed 0 --eval + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_24_battery \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_24_battery-seed-0_try2 \ +--policy_class DETRVAE --kl_weight 80 --window_len 120 --hidden_dim 512 --batch_size 4 --dim_feedforward 3000 \ +--task_name battery --seed 0 --eval + + +### Battery slotting try 3 + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_25_battery --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_25_battery \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_25_battery-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 90 --hidden_dim 512 --batch_size 4 --dim_feedforward 3000 \ +--task_name battery --seed 0 + + + +### Taping + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_26_tape --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_26_tape \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_26_tape-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 90 --hidden_dim 512 --batch_size 4 --dim_feedforward 3000 \ +--task_name tape --seed 0 + + +### Candy! + +python3 examples/record_episodes.py --dataset_dir /scr2/tonyzhao/datasets/1_27_candy --episode_idx 0 + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_27_candy \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_27_candy-seed-0 \ +--policy_class DETRVAE --kl_weight 80 --window_len 90 --hidden_dim 512 --batch_size 4 --dim_feedforward 3000 \ +--task_name tape --seed 0 + + + +python3 examples/imitate_episodes.py \ +--dataset_dir /scr2/tonyzhao/datasets/1_27_candy \ +--ckpt_dir /scr2/tonyzhao/train_logs/DETRVAE_1_27_candy-seed-0-1 \ +--policy_class DETRVAE --kl_weight 80 --window_len 120 --hidden_dim 512 --batch_size 4 --dim_feedforward 3000 \ +--task_name tape --seed 0 + + + + + + + + + diff --git a/constants.py b/constants.py new file mode 100644 index 0000000..e626194 --- /dev/null +++ b/constants.py @@ -0,0 +1,53 @@ +import pathlib + +### Parameters that changes across tasks +EPISODE_LEN = 600 + +### ALOHA fixed constants +DT = 0.02 +JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] +CAMERA_NAMES = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] # defines the number and ordering of cameras +BOX_INIT_POSE = [0.2, 0.5, 0.05, 1, 0, 0, 0] +START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] +SIM_CAMERA_NAMES = ['main'] + +SIM_EPISODE_LEN_TRANSFER_CUBE = 400 +SIM_EPISODE_LEN_INSERTION = 400 + +XML_DIR = str(pathlib.Path(__file__).parent.resolve()) + '/assets/' # note: absolute path + +# Left finger position limits (qpos[7]), right_finger = -1 * left_finger +MASTER_GRIPPER_POSITION_OPEN = 0.02417 +MASTER_GRIPPER_POSITION_CLOSE = 0.01244 +PUPPET_GRIPPER_POSITION_OPEN = 0.05800 +PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 + +# Gripper joint limits (qpos[6]) +MASTER_GRIPPER_JOINT_OPEN = 0.3083 +MASTER_GRIPPER_JOINT_CLOSE = -0.6842 +PUPPET_GRIPPER_JOINT_OPEN = 1.4910 +PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 + +############################ Helper functions ############################ + +MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) +PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) +MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE +PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE +MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) + +MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) +PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) +MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE +PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE +MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) + +MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) +PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + +MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE +MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)) +PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE +PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)) + +MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2 diff --git a/detr/LICENSE b/detr/LICENSE new file mode 100644 index 0000000..b1395e9 --- /dev/null +++ b/detr/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 - present, Facebook, Inc + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/detr/README.md b/detr/README.md new file mode 100644 index 0000000..500b1b8 --- /dev/null +++ b/detr/README.md @@ -0,0 +1,9 @@ +This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0. + + @article{Carion2020EndtoEndOD, + title={End-to-End Object Detection with Transformers}, + author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko}, + journal={ArXiv}, + year={2020}, + volume={abs/2005.12872} + } \ No newline at end of file diff --git a/detr/main.py b/detr/main.py new file mode 100644 index 0000000..213c5fb --- /dev/null +++ b/detr/main.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import argparse +from pathlib import Path + +import numpy as np +import torch +from .models import build_ACT_model, build_CNNMLP_model + +import IPython +e = IPython.embed + +def get_args_parser(): + parser = argparse.ArgumentParser('Set transformer detector', add_help=False) + parser.add_argument('--lr', default=1e-4, type=float) # will be overridden + parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden + parser.add_argument('--batch_size', default=2, type=int) # not used + parser.add_argument('--weight_decay', default=1e-4, type=float) + parser.add_argument('--epochs', default=300, type=int) # not used + parser.add_argument('--lr_drop', default=200, type=int) # not used + parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used + help='gradient clipping max norm') + + # Model parameters + # * Backbone + parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden + help="Name of the convolutional backbone to use") + parser.add_argument('--dilation', action='store_true', + help="If true, we replace stride with dilation in the last convolutional block (DC5)") + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features") + parser.add_argument('--camera_names', default=[], type=list, # will be overridden + help="A list of camera names") + + # * Transformer + parser.add_argument('--enc_layers', default=4, type=int, # will be overridden + help="Number of encoding layers in the transformer") + parser.add_argument('--dec_layers', default=6, type=int, # will be overridden + help="Number of decoding layers in the transformer") + parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden + help="Intermediate size of the feedforward layers in the transformer blocks") + parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden + help="Size of the embeddings (dimension of the transformer)") + parser.add_argument('--dropout', default=0.1, type=float, + help="Dropout applied in the transformer") + parser.add_argument('--nheads', default=8, type=int, # will be overridden + help="Number of attention heads inside the transformer's attentions") + parser.add_argument('--num_queries', default=400, type=int, # will be overridden + help="Number of query slots") + parser.add_argument('--pre_norm', action='store_true') + + # * Segmentation + parser.add_argument('--masks', action='store_true', + help="Train segmentation head if the flag is provided") + + # repeat args in imitate_episodes just to avoid error. Will not be used + parser.add_argument('--eval', action='store_true') + parser.add_argument('--onscreen_render', action='store_true') + parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True) + parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) + parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True) + parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) + parser.add_argument('--seed', action='store', type=int, help='seed', required=True) + parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True) + parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) + parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) + parser.add_argument('--temporal_agg', action='store_true') + + return parser + + +def build_ACT_model_and_optimizer(args_override): + parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) + args = parser.parse_args() + + for k, v in args_override.items(): + setattr(args, k, v) + + model = build_ACT_model(args) + model.cuda() + + param_dicts = [ + {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, + { + "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], + "lr": args.lr_backbone, + }, + ] + optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, + weight_decay=args.weight_decay) + + return model, optimizer + + +def build_CNNMLP_model_and_optimizer(args_override): + parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) + args = parser.parse_args() + + for k, v in args_override.items(): + setattr(args, k, v) + + model = build_CNNMLP_model(args) + model.cuda() + + param_dicts = [ + {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, + { + "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], + "lr": args.lr_backbone, + }, + ] + optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, + weight_decay=args.weight_decay) + + return model, optimizer + diff --git a/detr/models/__init__.py b/detr/models/__init__.py new file mode 100644 index 0000000..cc78db1 --- /dev/null +++ b/detr/models/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from .detr_vae import build as build_vae +from .detr_vae import build_cnnmlp as build_cnnmlp + +def build_ACT_model(args): + return build_vae(args) + +def build_CNNMLP_model(args): + return build_cnnmlp(args) \ No newline at end of file diff --git a/detr/models/backbone.py b/detr/models/backbone.py new file mode 100644 index 0000000..f28637e --- /dev/null +++ b/detr/models/backbone.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + +import IPython +e = IPython.embed + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? + # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/detr/models/detr_vae.py b/detr/models/detr_vae.py new file mode 100644 index 0000000..fc25edf --- /dev/null +++ b/detr/models/detr_vae.py @@ -0,0 +1,275 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +from torch import nn +from torch.autograd import Variable +from .backbone import build_backbone +from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer + +import numpy as np + +import IPython +e = IPython.embed + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DETRVAE(nn.Module): + """ This is the DETR module that performs object detection """ + def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names): + """ Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_proj = nn.Linear(14, hidden_dim) # project state to embedding + self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var + self.register_buffer('pos_table', get_sinusoid_encoding_table(num_queries+1, hidden_dim)) + + # decoder extra parameters + self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent + + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_proj(actions) # (bs, seq, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) + encoder_input = torch.cat([cls_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_is_pad = torch.full((bs, 1), False).to(qpos.device) # False: not a padding + is_pad = torch.cat([cls_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, :self.latent_dim] + logvar = latent_info[:, self.latent_dim:] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0] + a_hat = self.action_head(hs) + is_pad_hat = self.is_pad_head(hs) + return a_hat, is_pad_hat, [mu, logvar] + + + +class CNNMLP(nn.Module): + def __init__(self, backbones, state_dim, camera_names): + """ Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.camera_names = camera_names + self.action_head = nn.Linear(1000, state_dim) # TODO add more + if backbones is not None: + self.backbones = nn.ModuleList(backbones) + backbone_down_projs = [] + for backbone in backbones: + down_proj = nn.Sequential( + nn.Conv2d(backbone.num_channels, 128, kernel_size=5), + nn.Conv2d(128, 64, kernel_size=5), + nn.Conv2d(64, 32, kernel_size=5) + ) + backbone_down_projs.append(down_proj) + self.backbone_down_projs = nn.ModuleList(backbone_down_projs) + + mlp_in_dim = 768 * len(backbones) + 14 + self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2) + else: + raise NotImplementedError + + def forward(self, qpos, image, env_state, actions=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + # Image observation features and position embeddings + all_cam_features = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[cam_id](image[:, cam_id]) + features = features[0] # take the last layer feature + pos = pos[0] # not used + all_cam_features.append(self.backbone_down_projs[cam_id](features)) + # flatten everything + flattened_features = [] + for cam_feature in all_cam_features: + flattened_features.append(cam_feature.reshape([bs, -1])) + flattened_features = torch.cat(flattened_features, axis=1) # 768 each + features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14 + a_hat = self.mlp(features) + return a_hat + + +def mlp(input_dim, hidden_dim, output_dim, hidden_depth): + if hidden_depth == 0: + mods = [nn.Linear(input_dim, output_dim)] + else: + mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] + for i in range(hidden_depth - 1): + mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] + mods.append(nn.Linear(hidden_dim, output_dim)) + trunk = nn.Sequential(*mods) + return trunk + + +def build_encoder(args): + d_model = args.hidden_dim # 256 + dropout = args.dropout # 0.1 + nhead = args.nheads # 8 + dim_feedforward = args.dim_feedforward # 2048 + num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder + normalize_before = args.pre_norm # False + activation = "relu" + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + return encoder + + +def build(args): + state_dim = 14 # TODO hardcode + + # From state + # backbone = None # from state for now, no need for conv nets + # From image + backbones = [] + backbone = build_backbone(args) + backbones.append(backbone) + + transformer = build_transformer(args) + + encoder = build_encoder(args) + + model = DETRVAE( + backbones, + transformer, + encoder, + state_dim=state_dim, + num_queries=args.num_queries, + camera_names=args.camera_names, + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: %.2fM" % (n_parameters/1e6,)) + + return model + +def build_cnnmlp(args): + state_dim = 14 # TODO hardcode + + # From state + # backbone = None # from state for now, no need for conv nets + # From image + backbones = [] + for _ in args.camera_names: + backbone = build_backbone(args) + backbones.append(backbone) + + model = CNNMLP( + backbones, + state_dim=state_dim, + camera_names=args.camera_names, + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: %.2fM" % (n_parameters/1e6,)) + + return model + diff --git a/detr/models/position_encoding.py b/detr/models/position_encoding.py new file mode 100644 index 0000000..209d917 --- /dev/null +++ b/detr/models/position_encoding.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from util.misc import NestedTensor + +import IPython +e = IPython.embed + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/detr/models/transformer.py b/detr/models/transformer.py new file mode 100644 index 0000000..f38afd0 --- /dev/null +++ b/detr/models/transformer.py @@ -0,0 +1,314 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +import IPython +e = IPython.embed + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None): + # TODO flatten only when input has H and W + if len(src.shape) == 4: # has H and W + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # mask = mask.flatten(1) + + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + addition_input = torch.stack([latent_input, proprio_input], axis=0) + src = torch.cat([addition_input, src], axis=0) + else: + assert len(src.shape) == 3 + # flatten NxHWxC to HWxNxC + bs, hw, c = src.shape + src = src.permute(1, 0, 2) + pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + hs = hs.transpose(1, 2) + return hs + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/detr/setup.py b/detr/setup.py new file mode 100644 index 0000000..55d18c0 --- /dev/null +++ b/detr/setup.py @@ -0,0 +1,10 @@ +from distutils.core import setup +from setuptools import find_packages + +setup( + name='detr', + version='0.0.0', + packages=find_packages(), + license='MIT License', + long_description=open('README.md').read(), +) \ No newline at end of file diff --git a/detr/util/__init__.py b/detr/util/__init__.py new file mode 100644 index 0000000..168f997 --- /dev/null +++ b/detr/util/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/detr/util/box_ops.py b/detr/util/box_ops.py new file mode 100644 index 0000000..9c088e5 --- /dev/null +++ b/detr/util/box_ops.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/detr/util/misc.py b/detr/util/misc.py new file mode 100644 index 0000000..dfa9fb5 --- /dev/null +++ b/detr/util/misc.py @@ -0,0 +1,468 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from packaging import version +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if version.parse(torchvision.__version__) < version.parse('0.7'): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse('0.7'): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/detr/util/plot_utils.py b/detr/util/plot_utils.py new file mode 100644 index 0000000..0f24bed --- /dev/null +++ b/detr/util/plot_utils.py @@ -0,0 +1,107 @@ +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt + +from pathlib import Path, PurePath + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. + + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if not dir.exists(): + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + # verify log_name exists + fn = Path(dir / log_name) + if not fn.exists(): + print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?") + print(f"--> full path of missing log file: {fn}") + return + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame( + np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] + ).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs diff --git a/ee_sim_env.py b/ee_sim_env.py new file mode 100644 index 0000000..a51b0ab --- /dev/null +++ b/ee_sim_env.py @@ -0,0 +1,264 @@ +import numpy as np +import collections +import os + +from constants import DT, XML_DIR, START_ARM_POSE +from constants import PUPPET_GRIPPER_POSITION_CLOSE +from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN +from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN +from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN + +from utils import sample_box_pose, sample_insertion_pose +from dm_control import mujoco +from dm_control.rl import control +from dm_control.suite import base + +import IPython +e = IPython.embed + + +def make_ee_sim_env(task_name): + """ + Environment for simulated robot bi-manual manipulation, with end-effector control. + Action space: [left_arm_pose (7), # position and quaternion for end effector + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_pose (7), # position and quaternion for end effector + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + + Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' + """ + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_{task_name}.xml') + physics = mujoco.Physics.from_xml_path(xml_path) + if task_name == 'transfer_cube': + task = TransferCubeEETask(random=False) + env = control.Environment(physics, task, time_limit=20, control_timestep=DT, + n_sub_steps=None, flat_observation=False) + elif task_name == 'insertion': + task = InsertionEETask(random=False) + env = control.Environment(physics, task, time_limit=20, control_timestep=DT, + n_sub_steps=None, flat_observation=False) + else: + raise NotImplementedError + return env + +class BimanualViperXEETask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + a_len = len(action) // 2 + action_left = action[:a_len] + action_right = action[a_len:] + + # set mocap position and quat + # left + np.copyto(physics.data.mocap_pos[0], action_left[:3]) + np.copyto(physics.data.mocap_quat[0], action_left[3:7]) + # right + np.copyto(physics.data.mocap_pos[1], action_right[:3]) + np.copyto(physics.data.mocap_quat[1], action_right[3:7]) + + # set gripper + g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7]) + g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7]) + np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl])) + + def initialize_robots(self, physics): + # reset joint position + physics.named.data.qpos[:16] = START_ARM_POSE + + # reset mocap to align with end effector + # to obtain these numbers: + # (1) make an ee_sim env and reset to the same start_pose + # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link'] + # get env._physics.named.data.xquat['vx300s_left/gripper_link'] + # repeat the same for right side + np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084]) + np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0]) + # right + np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084])) + np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0]) + + # reset gripper control + close_gripper_control = np.array([ + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + ]) + np.copyto(physics.data.ctrl, close_gripper_control) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])] + right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])] + right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + # note: it is important to do .copy() + obs = collections.OrderedDict() + obs['qpos'] = self.get_qpos(physics) + obs['qvel'] = self.get_qvel(physics) + obs['env_state'] = self.get_env_state(physics) + obs['images'] = dict() + obs['images']['main'] = physics.render(height=480, width=640, camera_id='main') # TODO hardcoded camera name + + # used in scripted policy to obtain starting pose + obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy() + obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy() + + # used when replaying joint trajectory + obs['gripper_ctrl'] = physics.data.ctrl.copy() + return obs + + def get_reward(self, physics): + raise NotImplementedError + + +class TransferCubeEETask(BimanualViperXEETask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize box position + cube_pose = sample_box_pose() + box_start_idx = physics.model.name2id('red_box_joint', 'joint') + np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionEETask(BimanualViperXEETask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize peg and socket position + peg_pose, socket_pose = sample_insertion_pose() + id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky + + peg_start_id = physics.model.name2id('red_peg_joint', 'joint') + peg_start_idx = id2index(peg_start_id) + np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose) + # print(f"randomized cube position to {cube_position}") + + socket_start_id = physics.model.name2id('blue_socket_joint', 'joint') + socket_start_idx = id2index(socket_start_id) + np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ("socket-1", "table") in all_contact_pairs or \ + ("socket-2", "table") in all_contact_pairs or \ + ("socket-3", "table") in all_contact_pairs or \ + ("socket-4", "table") in all_contact_pairs + peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \ + ("red_peg", "socket-2") in all_contact_pairs or \ + ("red_peg", "socket-3") in all_contact_pairs or \ + ("red_peg", "socket-4") in all_contact_pairs + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/imitate_episodes.py b/imitate_episodes.py new file mode 100644 index 0000000..8f13401 --- /dev/null +++ b/imitate_episodes.py @@ -0,0 +1,436 @@ +import torch +import numpy as np +import os +import pickle +import argparse +import matplotlib.pyplot as plt +from copy import deepcopy +from tqdm import tqdm +from einops import rearrange + +from constants import DT, SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, EPISODE_LEN +from constants import PUPPET_GRIPPER_JOINT_OPEN, CAMERA_NAMES, SIM_CAMERA_NAMES +from utils import load_data # data functions +from utils import sample_box_pose, sample_insertion_pose # robot functions +from utils import compute_dict_mean, set_seed, detach_dict # helper functions +from policy import ACTPolicy, CNNMLPPolicy +from visualize_episodes import save_videos + +from sim_env import BOX_POSE + +import IPython +e = IPython.embed + +def main(args): + set_seed(1) + # command line parameters + is_eval = args['eval'] + ckpt_dir = args['ckpt_dir'] + dataset_dir = args['dataset_dir'] + policy_class = args['policy_class'] + onscreen_render = args['onscreen_render'] + task_name = args['task_name'] + batch_size_train = args['batch_size'] + batch_size_val = args['batch_size'] + num_epochs = args['num_epochs'] + + # fixed parameters + num_episodes = 50 + state_dim = 14 + lr_backbone = 1e-5 + backbone = 'resnet18' + if policy_class == 'ACT': + enc_layers = 4 + dec_layers = 7 + nheads = 8 + policy_config = {'lr': args['lr'], + 'num_queries': args['chunk_size'], + 'kl_weight': args['kl_weight'], + 'hidden_dim': args['hidden_dim'], + 'dim_feedforward': args['dim_feedforward'], + 'lr_backbone': lr_backbone, + 'backbone': backbone, + 'enc_layers': enc_layers, + 'dec_layers': dec_layers, + 'nheads': nheads, + } + elif policy_class == 'CNNMLP': + policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1} + else: + raise NotImplementedError + + config = { + 'num_epochs': num_epochs, + 'ckpt_dir': ckpt_dir, + 'state_dim': state_dim, + 'lr': args['lr'], + 'real_robot': 'TBD', + 'policy_class': policy_class, + 'onscreen_render': onscreen_render, + 'policy_config': policy_config, + 'task_name': task_name, + 'seed': args['seed'], + 'temporal_agg': args['temporal_agg'] + } + + train_dataloader, val_dataloader, stats, is_sim = load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val) + + if is_sim: + policy_config['camera_names'] = SIM_CAMERA_NAMES + config['camera_names'] = SIM_CAMERA_NAMES + config['real_robot'] = False + if task_name == 'transfer_cube': + config['episode_len'] = SIM_EPISODE_LEN_TRANSFER_CUBE + elif task_name == 'insertion': + config['episode_len'] = SIM_EPISODE_LEN_INSERTION + else: + policy_config['camera_names'] = CAMERA_NAMES + config['camera_names'] = CAMERA_NAMES + config['real_robot'] = True + config['episode_len'] = EPISODE_LEN + + if is_eval: + ckpt_names = [f'policy_best.ckpt'] + results = [] + for ckpt_name in ckpt_names: + success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True) + results.append([ckpt_name, success_rate, avg_return]) + + for ckpt_name, success_rate, avg_return in results: + print(f'{ckpt_name}: {success_rate=} {avg_return=}') + print() + exit() + + # save dataset stats + if not os.path.isdir(ckpt_dir): + os.makedirs(ckpt_dir) + stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl') + with open(stats_path, 'wb') as f: + pickle.dump(stats, f) + + best_ckpt_info = train_bc(train_dataloader, val_dataloader, config) + best_epoch, min_val_loss, best_state_dict = best_ckpt_info + + # save best checkpoint + ckpt_path = os.path.join(ckpt_dir, f'policy_best.ckpt') + torch.save(best_state_dict, ckpt_path) + print(f'Best ckpt, val loss {min_val_loss:.6f} @ epoch{best_epoch}') + + +def make_policy(policy_class, policy_config): + if policy_class == 'ACT': + policy = ACTPolicy(policy_config) + elif policy_class == 'CNNMLP': + policy = CNNMLPPolicy(policy_config) + else: + raise NotImplementedError + return policy + + +def make_optimizer(policy_class, policy): + if policy_class == 'ACT': + optimizer = policy.configure_optimizers() + elif policy_class == 'CNNMLP': + optimizer = policy.configure_optimizers() + else: + raise NotImplementedError + return optimizer + + +def get_image(ts, camera_names): + curr_images = [] + for cam_name in camera_names: + curr_image = rearrange(ts.observation['images'][cam_name], 'h w c -> c h w') + curr_images.append(curr_image) + curr_image = np.stack(curr_images, axis=0) + curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) + return curr_image + + +def eval_bc(config, ckpt_name, save_episode=True): + set_seed(1000) + ckpt_dir = config['ckpt_dir'] + state_dim = config['state_dim'] + real_robot = config['real_robot'] + policy_class = config['policy_class'] + onscreen_render = config['onscreen_render'] + policy_config = config['policy_config'] + camera_names = config['camera_names'] + max_timesteps = config['episode_len'] + task_name = config['task_name'] + temporal_agg = config['temporal_agg'] + onscreen_cam = 'main' + + # load policy and stats + ckpt_path = os.path.join(ckpt_dir, ckpt_name) + policy = make_policy(policy_class, policy_config) + loading_status = policy.load_state_dict(torch.load(ckpt_path)) + print(loading_status) + policy.cuda() + policy.eval() + print(f'Loaded: {ckpt_path}') + stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl') + with open(stats_path, 'rb') as f: + stats = pickle.load(f) + + pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + post_process = lambda a: a * stats['action_std'] + stats['action_mean'] + + # load environment + if real_robot: + from scripts.utils import move_grippers # requires aloha + from scripts.real_env import make_real_env # requires aloha + env = make_real_env(init_node=True) + env_max_reward = 0 + else: + from sim_env import make_sim_env + env = make_sim_env(task_name) + env_max_reward = env.task.max_reward + + query_frequency = policy_config['num_queries'] + if temporal_agg: + query_frequency = 1 + num_queries = policy_config['num_queries'] + + max_timesteps = int(max_timesteps * 1) # may increase for real-world tasks + + num_rollouts = 50 + episode_returns = [] + highest_rewards = [] + for rollout_id in range(num_rollouts): + rollout_id += 0 + ### set task + if task_name == 'transfer_cube': + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif task_name == 'insertion': + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + else: + raise NotImplementedError + ts = env.reset() + + ### onscreen render + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam)) + plt.ion() + + ### evaluation loop + if temporal_agg: + all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda() + + qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda() + image_list = [] # for visualization + qpos_list = [] + target_qpos_list = [] + rewards = [] + with torch.inference_mode(): + for t in range(max_timesteps): + ### update onscreen render and wait for DT + if onscreen_render: + image = env._physics.render(height=480, width=640, camera_id=onscreen_cam) + plt_img.set_data(image) + plt.pause(DT) + + ### process previous timestep to get qpos and image_list + obs = ts.observation + if 'images' in obs: + image_list.append(obs['images']) + else: + image_list.append({'main': obs['image']}) + qpos_numpy = np.array(obs['qpos']) + qpos = pre_process(qpos_numpy) + qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) + qpos_history[:, t] = qpos + curr_image = get_image(ts, camera_names) + + ### query policy + if config['policy_class'] == "ACT": + if t % query_frequency == 0: + all_actions = policy(qpos, curr_image) + if temporal_agg: + all_time_actions[[t], t:t+num_queries] = all_actions + actions_for_curr_step = all_time_actions[:, t] + actions_populated = torch.all(actions_for_curr_step != 0, axis=1) + actions_for_curr_step = actions_for_curr_step[actions_populated] + k = 0.01 + exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + exp_weights = exp_weights / exp_weights.sum() + exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) + raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + else: + raw_action = all_actions[:, t % query_frequency] + elif config['policy_class'] == "CNNMLP": + raw_action = policy(qpos, curr_image) + else: + raise NotImplementedError + + ### post-process actions + raw_action = raw_action.squeeze(0).cpu().numpy() + action = post_process(raw_action) + target_qpos = action + + ### step the environment + ts = env.step(target_qpos) + + ### for visualization + qpos_list.append(qpos_numpy) + target_qpos_list.append(target_qpos) + rewards.append(ts.reward) + + plt.close() + if real_robot: + move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # open + pass + + rewards = np.array(rewards) + episode_return = np.sum(rewards[rewards!=None]) + episode_returns.append(episode_return) + episode_highest_reward = np.max(rewards) + highest_rewards.append(episode_highest_reward) + print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, Success: {episode_highest_reward==env_max_reward}') + + if save_episode: + save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4')) + + success_rate = np.mean(np.array(highest_rewards) == env_max_reward) + avg_return = np.mean(episode_returns) + summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n' + for r in range(env_max_reward+1): + more_or_equal_r = (np.array(highest_rewards) >= r).sum() + more_or_equal_r_rate = more_or_equal_r / num_rollouts + summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n' + + print(summary_str) + + # save success rate to txt + result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt' + with open(os.path.join(ckpt_dir, result_file_name), 'w') as f: + f.write(summary_str) + f.write(repr(episode_returns)) + f.write('\n\n') + f.write(repr(highest_rewards)) + + return success_rate, avg_return + + +def forward_pass(data, policy): + image_data, qpos_data, action_data, is_pad = data + image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda() + return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None + + +def train_bc(train_dataloader, val_dataloader, config): + num_epochs = config['num_epochs'] + ckpt_dir = config['ckpt_dir'] + seed = config['seed'] + policy_class = config['policy_class'] + policy_config = config['policy_config'] + + set_seed(seed) + + policy = make_policy(policy_class, policy_config) + policy.cuda() + optimizer = make_optimizer(policy_class, policy) + + train_history = [] + validation_history = [] + min_val_loss = np.inf + best_ckpt_info = None + for epoch in tqdm(range(num_epochs)): + print(f'\nEpoch {epoch}') + # validation + with torch.inference_mode(): + policy.eval() + epoch_dicts = [] + for batch_idx, data in enumerate(val_dataloader): + forward_dict = forward_pass(data, policy) + epoch_dicts.append(forward_dict) + epoch_summary = compute_dict_mean(epoch_dicts) + validation_history.append(epoch_summary) + + epoch_val_loss = epoch_summary['loss'] + if epoch_val_loss < min_val_loss: + min_val_loss = epoch_val_loss + best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict())) + print(f'Val loss: {epoch_val_loss:.5f}') + summary_string = '' + for k, v in epoch_summary.items(): + summary_string += f'{k}: {v.item():.3f} ' + print(summary_string) + + # training + policy.train() + optimizer.zero_grad() + for batch_idx, data in enumerate(train_dataloader): + forward_dict = forward_pass(data, policy) + # backward + loss = forward_dict['loss'] + loss.backward() + optimizer.step() + optimizer.zero_grad() + train_history.append(detach_dict(forward_dict)) + epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)]) + epoch_train_loss = epoch_summary['loss'] + print(f'Train loss: {epoch_train_loss:.5f}') + summary_string = '' + for k, v in epoch_summary.items(): + summary_string += f'{k}: {v.item():.3f} ' + print(summary_string) + + if epoch % 100 == 0: + ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt') + torch.save(policy.state_dict(), ckpt_path) + plot_history(train_history, validation_history, epoch, ckpt_dir, seed) + + ckpt_path = os.path.join(ckpt_dir, f'policy_last.ckpt') + torch.save(policy.state_dict(), ckpt_path) + + best_epoch, min_val_loss, best_state_dict = best_ckpt_info + ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt') + torch.save(best_state_dict, ckpt_path) + print(f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}') + + # save training curves + plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed) + + return best_ckpt_info + + +def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed): + # save training curves + for key in train_history[0]: + plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png') + plt.figure() + train_values = [summary[key].item() for summary in train_history] + val_values = [summary[key].item() for summary in validation_history] + plt.plot(np.linspace(0, num_epochs-1, len(train_history)), train_values, label='train') + plt.plot(np.linspace(0, num_epochs-1, len(validation_history)), val_values, label='validation') + # plt.ylim([-0.1, 1]) + plt.tight_layout() + plt.legend() + plt.title(key) + plt.savefig(plot_path) + print(f'Saved plots to {ckpt_dir}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--eval', action='store_true') + parser.add_argument('--onscreen_render', action='store_true') + parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True) + parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) + parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True) + parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) + parser.add_argument('--batch_size', action='store', type=int, help='batch_size', required=True) + parser.add_argument('--seed', action='store', type=int, help='seed', required=True) + parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True) + parser.add_argument('--lr', action='store', type=float, help='lr', required=True) + + # for ACT + parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) + parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) + parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False) + parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False) + parser.add_argument('--temporal_agg', action='store_true') + + main(vars(parser.parse_args())) diff --git a/policy.py b/policy.py new file mode 100644 index 0000000..7b091e5 --- /dev/null +++ b/policy.py @@ -0,0 +1,84 @@ +import torch.nn as nn +from torch.nn import functional as F +import torchvision.transforms as transforms + +from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer +import IPython +e = IPython.embed + +class ACTPolicy(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_ACT_model_and_optimizer(args_override) + self.model = model # CVAE decoder + self.optimizer = optimizer + self.kl_weight = args_override['kl_weight'] + print(f'KL Weight {self.kl_weight}') + + def __call__(self, qpos, image, actions=None, is_pad=None): + env_state = None + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + image = normalize(image) + if actions is not None: # training time + actions = actions[:, :self.model.num_queries] + is_pad = is_pad[:, :self.model.num_queries] + + a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad) + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict = dict() + all_l1 = F.l1_loss(actions, a_hat, reduction='none') + l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean() + loss_dict['l1'] = l1 + loss_dict['kl'] = total_kld[0] + loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight + return loss_dict + else: # inference time + a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior + return a_hat + + def configure_optimizers(self): + return self.optimizer + + +class CNNMLPPolicy(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_CNNMLP_model_and_optimizer(args_override) + self.model = model # decoder + self.optimizer = optimizer + + def __call__(self, qpos, image, actions=None, is_pad=None): + env_state = None # TODO + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + image = normalize(image) + if actions is not None: # training time + actions = actions[:, 0] + a_hat = self.model(qpos, image, env_state, actions) + mse = F.mse_loss(actions, a_hat) + loss_dict = dict() + loss_dict['mse'] = mse + loss_dict['loss'] = loss_dict['mse'] + return loss_dict + else: # inference time + a_hat = self.model(qpos, image, env_state) # no action, sample from prior + return a_hat + + def configure_optimizers(self): + return self.optimizer + +def kl_divergence(mu, logvar): + batch_size = mu.size(0) + assert batch_size != 0 + if mu.data.ndimension() == 4: + mu = mu.view(mu.size(0), mu.size(1)) + if logvar.data.ndimension() == 4: + logvar = logvar.view(logvar.size(0), logvar.size(1)) + + klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) + total_kld = klds.sum(1).mean(0, True) + dimension_wise_kld = klds.mean(0) + mean_kld = klds.mean(1).mean(0, True) + + return total_kld, dimension_wise_kld, mean_kld diff --git a/record_sim_episodes.py b/record_sim_episodes.py new file mode 100644 index 0000000..9bbebe0 --- /dev/null +++ b/record_sim_episodes.py @@ -0,0 +1,187 @@ +import time +import os +import numpy as np +import argparse +import matplotlib.pyplot as plt +import h5py_cache + +from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN +from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, SIM_CAMERA_NAMES +from ee_sim_env import make_ee_sim_env +from sim_env import make_sim_env, BOX_POSE +from scripted_policy import PickAndTransferPolicy, InsertionPolicy + +import IPython +e = IPython.embed + + +def main(args): + """ + Generate demonstration data in simulation. + First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory. + Replace the gripper joint positions with the commanded joint position. + Replay this joint trajectory (as action sequence) in sim_env, and record all observations. + Save this episode of data, and continue to next episode of data collection. + """ + + task_name = args['task_name'] + dataset_dir = args['dataset_dir'] + num_episodes = args['num_episodes'] + onscreen_render = args['onscreen_render'] + inject_noise = False + + if not os.path.isdir(dataset_dir): + os.makedirs(dataset_dir, exist_ok=True) + + if task_name == 'transfer_cube': + policy_cls = PickAndTransferPolicy + episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE + elif task_name == 'insertion': + policy_cls = InsertionPolicy + episode_len = SIM_EPISODE_LEN_INSERTION + else: + raise NotImplementedError + + success = [] + for episode_idx in range(num_episodes): + # setup the environment + env = make_ee_sim_env(task_name) + ts = env.reset() + episode = [ts] + policy = policy_cls(inject_noise) + # setup plotting + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images']['main']) + plt.ion() + for step in range(episode_len): + action = policy(ts) + ts = env.step(action) + episode.append(ts) + if onscreen_render: + plt_img.set_data(ts.observation['images']['main']) + plt.pause(0.002) + plt.close() + + episode_return = np.sum([ts.reward for ts in episode[1:]]) + episode_max_reward = np.max([ts.reward for ts in episode[1:]]) + if episode_max_reward == env.task.max_reward: + print(f"{episode_idx=} Successful, {episode_return=}") + else: + print(f"{episode_idx=} Failed") + + joint_traj = [ts.observation['qpos'] for ts in episode] + # replace gripper pose with gripper control + gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode] + for joint, ctrl in zip(joint_traj, gripper_ctrl_traj): + left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0]) + right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2]) + joint[6] = left_ctrl + joint[6+7] = right_ctrl + + subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0 + + # clear unused variables + del env + del episode + del policy + + # setup the environment + print(f'====== Start Replaying ======') + env = make_sim_env(task_name) + BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env + ts = env.reset() + + episode_replay = [ts] + # setup plotting + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images']['main']) + plt.ion() + for t in range(len(joint_traj)): # note: this will increase episode length by 1 + action = joint_traj[t] + ts = env.step(action) + episode_replay.append(ts) + if onscreen_render: + plt_img.set_data(ts.observation['images']['main']) + plt.pause(0.02) + + episode_return = np.sum([ts.reward for ts in episode_replay[1:]]) + episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]]) + if episode_max_reward == env.task.max_reward: + success.append(1) + print(f"{episode_idx=} Successful, {episode_return=}") + else: + success.append(0) + print(f"{episode_idx=} Failed") + + plt.close() + + """ + For each timestep: + observations + - images + - main (480, 640, 3) 'uint8' + - qpos (14,) 'float64' + - qvel (14,) 'float64' + + action (14,) 'float64' + """ + + data_dict = { + '/observations/qpos': [], + '/observations/qvel': [], + '/action': [], + } + for cam_name in SIM_CAMERA_NAMES: + data_dict[f'/observations/images/{cam_name}'] = [] + + # because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps + # truncate here to be consistent + joint_traj = joint_traj[:-1] + episode_replay = episode_replay[:-1] + + # len(joint_traj) i.e. actions: max_timesteps + # len(episode_replay) i.e. time steps: max_timesteps + 1 + max_timesteps = len(joint_traj) + while joint_traj: + action = joint_traj.pop(0) + ts = episode_replay.pop(0) + data_dict['/observations/qpos'].append(ts.observation['qpos']) + data_dict['/observations/qvel'].append(ts.observation['qvel']) + data_dict['/action'].append(action) + for cam_name in SIM_CAMERA_NAMES: + data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name]) + + # HDF5 + t0 = time.time() + dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}') + with h5py_cache.File(dataset_path + '.hdf5', 'w', chunk_cache_mem_size=1024 ** 2 * 2) as root: + # with h5py.File(dataset_path + '.hdf5', 'w') as root: + root.attrs['sim'] = True + obs = root.create_group('observations') + image = obs.create_group('images') + cam_main = image.create_dataset('main', (max_timesteps, 480, 640, 3), dtype='uint8', + chunks=(1, 480, 640, 3), ) + # compression='gzip',compression_opts=2,) + # compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False) + qpos = obs.create_dataset('qpos', (max_timesteps, 14)) + qvel = obs.create_dataset('qvel', (max_timesteps, 14)) + action = root.create_dataset('action', (max_timesteps, 14)) + + for name, array in data_dict.items(): + root[name][...] = array + print(f'Saving: {time.time() - t0:.1f} secs\n') + + print(f'Saved to {dataset_dir}') + print(f'Success: {np.sum(success)} / {len(success)}') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) + parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True) + parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False) + parser.add_argument('--onscreen_render', action='store_true') + + main(vars(parser.parse_args())) + diff --git a/scripted_policy.py b/scripted_policy.py new file mode 100644 index 0000000..dcd612e --- /dev/null +++ b/scripted_policy.py @@ -0,0 +1,195 @@ +import numpy as np +import matplotlib.pyplot as plt +from pyquaternion import Quaternion + +from ee_sim_env import make_ee_sim_env + +import IPython +e = IPython.embed + + +class BasePolicy: + def __init__(self, inject_noise=False): + self.inject_noise = inject_noise + self.step_count = 0 + self.left_trajectory = None + self.right_trajectory = None + + def generate_trajectory(self, ts_first): + raise NotImplementedError + + @staticmethod + def interpolate(curr_waypoint, next_waypoint, t): + t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"]) + curr_xyz = curr_waypoint['xyz'] + curr_quat = curr_waypoint['quat'] + curr_grip = curr_waypoint['gripper'] + next_xyz = next_waypoint['xyz'] + next_quat = next_waypoint['quat'] + next_grip = next_waypoint['gripper'] + xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac + quat = curr_quat + (next_quat - curr_quat) * t_frac + gripper = curr_grip + (next_grip - curr_grip) * t_frac + return xyz, quat, gripper + + def __call__(self, ts): + # generate trajectory at first timestep, then open-loop execution + if self.step_count == 0: + self.generate_trajectory(ts) + + # obtain left and right waypoints + if self.left_trajectory[0]['t'] == self.step_count: + self.curr_left_waypoint = self.left_trajectory.pop(0) + next_left_waypoint = self.left_trajectory[0] + + if self.right_trajectory[0]['t'] == self.step_count: + self.curr_right_waypoint = self.right_trajectory.pop(0) + next_right_waypoint = self.right_trajectory[0] + + # interpolate between waypoints to obtain current pose and gripper command + left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count) + right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count) + + # Inject noise + if self.inject_noise: + scale = 0.01 + left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape) + right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape) + + action_left = np.concatenate([left_xyz, left_quat, [left_gripper]]) + action_right = np.concatenate([right_xyz, right_quat, [right_gripper]]) + + self.step_count += 1 + return np.concatenate([action_left, action_right]) + + +class PickAndTransferPolicy(BasePolicy): + + def generate_trajectory(self, ts_first): + init_mocap_pose_right = ts_first.observation['mocap_pose_right'] + init_mocap_pose_left = ts_first.observation['mocap_pose_left'] + + box_info = np.array(ts_first.observation['env_state']) + box_xyz = box_info[:3] + box_quat = box_info[3:] + # print(f"Generate trajectory for {box_xyz=}") + + gripper_pick_quat = Quaternion(init_mocap_pose_right[3:]) + gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60) + + meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90) + + meet_xyz = np.array([0, 0.5, 0.25]) + + self.left_trajectory = [ + {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep + {"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position + {"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position + {"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper + {"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left + {"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay + ] + + self.right_trajectory = [ + {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep + {"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube + {"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down + {"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper + {"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position + {"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position + {"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper + {"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right + {"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay + ] + + +class InsertionPolicy(BasePolicy): + + def generate_trajectory(self, ts_first): + init_mocap_pose_right = ts_first.observation['mocap_pose_right'] + init_mocap_pose_left = ts_first.observation['mocap_pose_left'] + + peg_info = np.array(ts_first.observation['env_state'])[:7] + peg_xyz = peg_info[:3] + peg_quat = peg_info[3:] + + socket_info = np.array(ts_first.observation['env_state'])[7:] + socket_xyz = socket_info[:3] + socket_quat = socket_info[3:] + + gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:]) + gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60) + + gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:]) + gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60) + + meet_xyz = np.array([0, 0.5, 0.15]) + lift_right = 0.00715 + + self.left_trajectory = [ + {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep + {"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube + {"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down + {"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper + {"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position + {"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion + {"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion + ] + + self.right_trajectory = [ + {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep + {"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube + {"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down + {"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper + {"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position + {"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion + {"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion + + ] + + +def test_policy(task_name): + # example rolling out pick_and_transfer policy + onscreen_render = True + inject_noise = False + + # setup the environment + from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION + if task_name == 'transfer_cube': + env = make_ee_sim_env('transfer_cube') + episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE + elif task_name == 'insertion': + env = make_ee_sim_env('insertion') + episode_len = SIM_EPISODE_LEN_INSERTION + else: + raise NotImplementedError + + for episode_idx in range(2): + ts = env.reset() + episode = [ts] + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images']['main']) + plt.ion() + + policy = PickAndTransferPolicy(inject_noise) + for step in range(episode_len): + action = policy(ts) + ts = env.step(action) + episode.append(ts) + if onscreen_render: + plt_img.set_data(ts.observation['images']['main']) + plt.pause(0.02) + plt.close() + + episode_return = np.sum([ts.reward for ts in episode[1:]]) + if episode_return > 0: + print(f"{episode_idx=} Successful, {episode_return=}") + else: + print(f"{episode_idx=} Failed") + + +if __name__ == '__main__': + test_task_name = 'transfer_cube' + test_policy(test_task_name) + diff --git a/sim_env.py b/sim_env.py new file mode 100644 index 0000000..55828e7 --- /dev/null +++ b/sim_env.py @@ -0,0 +1,274 @@ +import numpy as np +import os +import collections +import matplotlib.pyplot as plt +from dm_control import mujoco +from dm_control.rl import control +from dm_control.suite import base + +from constants import DT, XML_DIR, START_ARM_POSE, BOX_INIT_POSE +from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN +from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN +from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN +from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN + +import IPython +e = IPython.embed + +BOX_POSE = [None] # to be changed from outside + +def make_sim_env(task_name): + """ + Environment for simulated robot bi-manual manipulation, with joint position control + Action space: [left_arm_qpos (6), # absolute joint position + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + + Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' + """ + xml_path = os.path.join(XML_DIR, f'bimanual_viperx_{task_name}.xml') + physics = mujoco.Physics.from_xml_path(xml_path) + if task_name == 'transfer_cube': + task = TransferCubeTask(random=False) + env = control.Environment(physics, task, time_limit=20, control_timestep=DT, + n_sub_steps=None, flat_observation=False) + elif task_name == 'insertion': + task = InsertionTask(random=False) + env = control.Environment(physics, task, time_limit=20, control_timestep=DT, + n_sub_steps=None, flat_observation=False) + else: + raise NotImplementedError + return env + +class BimanualViperXTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + left_arm_action = action[:6] + right_arm_action = action[7:7+6] + normalized_left_gripper_action = action[6] + normalized_right_gripper_action = action[7+6] + + left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action) + right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action) + + full_left_gripper_action = [left_gripper_action, -left_gripper_action] + full_right_gripper_action = [right_gripper_action, -right_gripper_action] + + env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]) + super().before_step(env_action, physics) + return + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])] + right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])] + right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + obs = collections.OrderedDict() + obs['qpos'] = self.get_qpos(physics) + obs['qvel'] = self.get_qvel(physics) + obs['env_state'] = self.get_env_state(physics) + obs['images'] = dict() + obs['images']['main'] = physics.render(height=480, width=640, camera_id='top') # TODO hardcoded camera name + obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close') # TODO hardcoded camera name + + return obs + + def get_reward(self, physics): + # return whether left gripper is holding the box + raise NotImplementedError + + +class TransferCubeTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7:] = BOX_POSE[0] + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, 'geom') + name_geom_2 = physics.model.id2name(id_geom_2, 'geom') + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \ + ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ("socket-1", "table") in all_contact_pairs or \ + ("socket-2", "table") in all_contact_pairs or \ + ("socket-3", "table") in all_contact_pairs or \ + ("socket-4", "table") in all_contact_pairs + peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \ + ("red_peg", "socket-2") in all_contact_pairs or \ + ("red_peg", "socket-3") in all_contact_pairs or \ + ("red_peg", "socket-4") in all_contact_pairs + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward + + +def get_action(master_bot_left, master_bot_right): + action = np.zeros(14) + # arm action + action[:6] = master_bot_left.dxl.joint_states.position[:6] + action[7:7+6] = master_bot_right.dxl.joint_states.position[:6] + # gripper action + left_gripper_pos = master_bot_left.dxl.joint_states.position[7] + right_gripper_pos = master_bot_right.dxl.joint_states.position[7] + normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos) + normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos) + action[6] = normalized_left_pos + action[7+6] = normalized_right_pos + return action + +def test_sim_teleop(): + from interbotix_xs_modules.arm import InterbotixManipulatorXS + + BOX_POSE[0] = BOX_INIT_POSE + + # source of data + master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_left', init_node=True) + master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_right', init_node=False) + + # setup the environment + env = make_sim_env() + ts = env.reset() + episode = [ts] + # setup plotting + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['image']) + plt.ion() + + for t in range(1000): + action = get_action(master_bot_left, master_bot_right) + ts = env.step(action) + episode.append(ts) + + plt_img.set_data(ts.observation['image']) + plt.pause(0.02) + + +if __name__ == '__main__': + test_sim_teleop() + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..d3851b5 --- /dev/null +++ b/utils.py @@ -0,0 +1,192 @@ +import numpy as np +import torch +import os +import h5py +from torch.utils.data import TensorDataset, DataLoader +from constants import SIM_CAMERA_NAMES, CAMERA_NAMES + +import IPython +e = IPython.embed + +class EpisodicDataset(torch.utils.data.Dataset): + def __init__(self, episode_ids, dataset_dir, norm_stats): + super(EpisodicDataset).__init__() + self.episode_ids = episode_ids + self.dataset_dir = dataset_dir + self.norm_stats = norm_stats + self.is_sim = None + self.__getitem__(0) # initialize self.is_sim + + def __len__(self): + return len(self.episode_ids) + + def __getitem__(self, index): + sample_full_episode = False # hardcode + + episode_id = self.episode_ids[index] + dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5') + with h5py.File(dataset_path, 'r') as root: + is_sim = root.attrs['sim'] + if is_sim: + camera_names = SIM_CAMERA_NAMES + else: + camera_names = CAMERA_NAMES + original_action_shape = root['/action'].shape + episode_len = original_action_shape[0] + if sample_full_episode: + start_ts = 0 + else: + start_ts = np.random.choice(episode_len) + # get observation at start_ts only + qpos = root['/observations/qpos'][start_ts] + qvel = root['/observations/qvel'][start_ts] + image_dict = dict() + for cam_name in camera_names: + image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts] + # get all actions after and including start_ts + if is_sim: + action = root['/action'][start_ts:] + action_len = episode_len - start_ts + else: + action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned + action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned + + self.is_sim = is_sim + padded_action = np.zeros(original_action_shape, dtype=np.float32) + padded_action[:action_len] = action + is_pad = np.zeros(episode_len) + is_pad[action_len:] = 1 + + # new axis for different cameras + all_cam_images = [] + for cam_name in camera_names: + all_cam_images.append(image_dict[cam_name]) + all_cam_images = np.stack(all_cam_images, axis=0) + + # construct observations + image_data = torch.from_numpy(all_cam_images) + qpos_data = torch.from_numpy(qpos).float() + action_data = torch.from_numpy(padded_action).float() + is_pad = torch.from_numpy(is_pad).bool() + + # channel last + image_data = torch.einsum('k h w c -> k c h w', image_data) + + # normalize image and change dtype to float + image_data = image_data / 255.0 + action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"] + qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"] + + return image_data, qpos_data, action_data, is_pad + + +def get_norm_stats(dataset_dir, num_episodes): + all_qpos_data = [] + all_action_data = [] + for episode_idx in range(num_episodes): + dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5') + with h5py.File(dataset_path, 'r') as root: + qpos = root['/observations/qpos'][()] + qvel = root['/observations/qvel'][()] + action = root['/action'][()] + all_qpos_data.append(torch.from_numpy(qpos)) + all_action_data.append(torch.from_numpy(action)) + all_qpos_data = torch.stack(all_qpos_data) + all_action_data = torch.stack(all_action_data) + all_action_data = all_action_data + + # normalize action data + action_mean = all_action_data.mean(dim=[0, 1], keepdim=True) + action_std = all_action_data.std(dim=[0, 1], keepdim=True) + action_std = torch.clip(action_std, 1e-2, 10) # clipping + + # normalize qpos data + qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True) + qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True) + qpos_std = torch.clip(qpos_std, 1e-2, 10) # clipping + + stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(), + "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(), + "example_qpos": qpos} + + return stats + + +def load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val): + # obtain train test split + train_ratio = 0.8 # TODO + shuffled_indices = np.random.permutation(num_episodes) + train_indices = shuffled_indices[:int(train_ratio * num_episodes)] + val_indices = shuffled_indices[int(train_ratio * num_episodes):] + + # obtain normalization stats for qpos and action + norm_stats = get_norm_stats(dataset_dir, num_episodes) + + # construct dataset and dataloader + train_dataset = EpisodicDataset(train_indices, dataset_dir, norm_stats) + val_dataset = EpisodicDataset(val_indices, dataset_dir, norm_stats) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1) + + return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim + + +### env utils + +def sample_box_pose(): + x_range = [0.0, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + cube_quat = np.array([1, 0, 0, 0]) + return np.concatenate([cube_position, cube_quat]) + +def sample_insertion_pose(): + # Peg + x_range = [0.1, 0.2] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + peg_quat = np.array([1, 0, 0, 0]) + peg_pose = np.concatenate([peg_position, peg_quat]) + + # Socket + x_range = [-0.2, -0.1] + y_range = [0.4, 0.6] + z_range = [0.05, 0.05] + + ranges = np.vstack([x_range, y_range, z_range]) + socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) + + socket_quat = np.array([1, 0, 0, 0]) + socket_pose = np.concatenate([socket_position, socket_quat]) + + return peg_pose, socket_pose + +### helper functions + +def compute_dict_mean(epoch_dicts): + result = {k: None for k in epoch_dicts[0]} + num_items = len(epoch_dicts) + for k in result: + value_sum = 0 + for epoch_dict in epoch_dicts: + value_sum += epoch_dict[k] + result[k] = value_sum / num_items + return result + +def detach_dict(d): + new_d = dict() + for k, v in d.items(): + new_d[k] = v.detach() + return new_d + +def set_seed(seed): + torch.manual_seed(seed) + np.random.seed(seed) diff --git a/visualize_episodes.py b/visualize_episodes.py new file mode 100644 index 0000000..9fb315e --- /dev/null +++ b/visualize_episodes.py @@ -0,0 +1,148 @@ +import os +import numpy as np +import cv2 +import h5py +import argparse + +import matplotlib.pyplot as plt +from constants import DT, CAMERA_NAMES, SIM_CAMERA_NAMES + +import IPython +e = IPython.embed + +JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] +STATE_NAMES = JOINT_NAMES + ["gripper"] + +def load_hdf5(dataset_dir, dataset_name): + dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5') + if not os.path.isfile(dataset_path): + print(f'Dataset does not exist at \n{dataset_path}\n') + exit() + + with h5py.File(dataset_path, 'r') as root: + is_sim = root.attrs['sim'] + qpos = root['/observations/qpos'][()] + qvel = root['/observations/qvel'][()] + action = root['/action'][()] + image_dict = dict() + camera_names = SIM_CAMERA_NAMES if is_sim else CAMERA_NAMES + for cam_name in camera_names: + image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()] + + return qpos, qvel, action, image_dict + +def main(args): + dataset_dir = args['dataset_dir'] + episode_idx = args['episode_idx'] + dataset_name = f'episode_{episode_idx}' + + qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name) + save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4')) + visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png')) + # visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back + + +def save_videos(video, dt, video_path=None): + if isinstance(video, list): + cam_names = list(video[0].keys()) + h, w, _ = video[0][cam_names[0]].shape + w = w * len(cam_names) + fps = int(1/dt) + out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for ts, image_dict in enumerate(video): + images = [] + for cam_name in cam_names: + image = image_dict[cam_name] + image = image[:, :, [2, 1, 0]] # swap B and R channel + images.append(image) + images = np.concatenate(images, axis=1) + out.write(images) + out.release() + print(f'Saved video to: {video_path}') + elif isinstance(video, dict): + cam_names = list(video.keys()) + all_cam_videos = [] + for cam_name in cam_names: + all_cam_videos.append(video[cam_name]) + all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension + + n_frames, h, w, _ = all_cam_videos.shape + fps = int(1 / dt) + out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for t in range(n_frames): + image = all_cam_videos[t] + image = image[:, :, [2, 1, 0]] # swap B and R channel + out.write(image) + out.release() + print(f'Saved video to: {video_path}') + + +def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None): + if label_overwrite: + label1, label2 = label_overwrite + else: + label1, label2 = 'State', 'Command' + + qpos = np.array(qpos_list) # ts, dim + command = np.array(command_list) + num_ts, num_dim = qpos.shape + h, w = 2, num_dim + num_figs = num_dim + fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs)) + + # plot joint state + all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES] + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(qpos[:, dim_idx], label=label1) + ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}') + ax.legend() + + # plot arm command + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(command[:, dim_idx], label=label2) + ax.legend() + + if ylim: + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.set_ylim(ylim) + + plt.tight_layout() + plt.savefig(plot_path) + print(f'Saved qpos plot to: {plot_path}') + plt.close() + +def visualize_timestamp(t_list, dataset_path): + plot_path = dataset_path.replace('.pkl', '_timestamp.png') + h, w = 4, 10 + fig, axs = plt.subplots(2, 1, figsize=(w, h*2)) + # process t_list + t_float = [] + for secs, nsecs in t_list: + t_float.append(secs + nsecs * 10E-10) + t_float = np.array(t_float) + + ax = axs[0] + ax.plot(np.arange(len(t_float)), t_float) + ax.set_title(f'Camera frame timestamps') + ax.set_xlabel('timestep') + ax.set_ylabel('time (sec)') + + ax = axs[1] + ax.plot(np.arange(len(t_float)-1), t_float[:-1] - t_float[1:]) + ax.set_title(f'dt') + ax.set_xlabel('timestep') + ax.set_ylabel('time (sec)') + + plt.tight_layout() + plt.savefig(plot_path) + print(f'Saved timestamp plot to: {plot_path}') + plt.close() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True) + parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False) + main(vars(parser.parse_args()))