-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathconfig.py
129 lines (121 loc) · 4.66 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import jax
import jax.numpy as jnp
import numpy as np
from franka_env.envs.wrappers import (
Quat2EulerWrapper,
SpacemouseIntervention,
MultiCameraBinaryRewardClassifierWrapper,
GripperCloseEnv
)
from franka_env.envs.relative_env import RelativeFrame
from franka_env.envs.franka_env import DefaultEnvConfig
from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper
from serl_launcher.wrappers.chunking import ChunkingWrapper
from serl_launcher.networks.reward_classifier import load_classifier_func
from experiments.config import DefaultTrainingConfig
from experiments.ram_insertion.wrapper import RAMEnv
class EnvConfig(DefaultEnvConfig):
    """Environment constants for the RAM insertion experiment.

    Every attribute here is a class-level override on top of
    ``DefaultEnvConfig``: robot server address, camera setup, reference poses
    with the workspace box around them, reset randomisation, and two
    controller parameter sets.
    """

    # ---- Robot server -----------------------------------------------------
    SERVER_URL = "http://127.0.0.2:5000/"

    # ---- Cameras ----------------------------------------------------------
    # One entry per camera stream; the keys are reused by IMAGE_CROP below.
    # NOTE(review): "dim" looks like (width, height) and "exposure" is in
    # whatever unit the camera driver expects -- confirm against the driver.
    REALSENSE_CAMERAS = {
        "wrist_1": {
            "serial_number": "127122270146",
            "dim": (1280, 720),
            "exposure": 40000,
        },
        "wrist_2": {
            "serial_number": "127122270350",
            "dim": (1280, 720),
            "exposure": 40000,
        },
    }

    # Per-camera crop callables: each slices the first two axes of the image.
    # NOTE(review): presumably axis 0 is rows and axis 1 is columns -- confirm.
    IMAGE_CROP = {
        "wrist_1": lambda img: img[150:450, 350:1100],
        "wrist_2": lambda img: img[100:500, 400:900],
    }

    # ---- Reference poses --------------------------------------------------
    # Six-component poses. NOTE(review): presumably (x, y, z, roll, pitch,
    # yaw) in the robot base frame -- confirm.
    TARGET_POSE = np.array(
        [
            0.5881241235410154,
            -0.03578590131997776,
            0.27843494179085326,
            np.pi,
            0,
            0,
        ]
    )
    GRASP_POSE = np.array(
        [
            0.5857508505445138,
            -0.22036261105675414,
            0.2731021902359492,
            np.pi,
            0,
            0,
        ]
    )

    # Reset pose is the target shifted by 0.05 in the 3rd and 5th components.
    RESET_POSE = TARGET_POSE + np.array([0, 0, 0.05, 0, 0.05, 0])

    # Allowed pose box around the target. The margins are symmetric except for
    # the 3rd component (0.01 below the target, 0.05 above it).
    ABS_POSE_LIMIT_LOW = TARGET_POSE - np.array([0.03, 0.02, 0.01, 0.01, 0.1, 0.4])
    ABS_POSE_LIMIT_HIGH = TARGET_POSE + np.array([0.03, 0.02, 0.05, 0.01, 0.1, 0.4])

    # ---- Reset randomisation / episode settings ---------------------------
    RANDOM_RESET = True
    RANDOM_XY_RANGE = 0.02
    RANDOM_RZ_RANGE = 0.05
    # NOTE(review): presumably (translation, rotation, gripper) scaling --
    # confirm against the environment that consumes it.
    ACTION_SCALE = (0.01, 0.06, 1)
    DISPLAY_IMAGE = True
    MAX_EPISODE_LENGTH = 100

    # ---- Controller parameter sets ----------------------------------------
    # Tight, asymmetric clip values (e.g. +x 0.0075 vs -x 0.002).
    COMPLIANCE_PARAM = {
        "translational_stiffness": 2000,
        "translational_damping": 89,
        "rotational_stiffness": 150,
        "rotational_damping": 7,
        "translational_Ki": 0,
        "translational_clip_x": 0.0075,
        "translational_clip_y": 0.0016,
        "translational_clip_z": 0.0055,
        "translational_clip_neg_x": 0.002,
        "translational_clip_neg_y": 0.0016,
        "translational_clip_neg_z": 0.005,
        "rotational_clip_x": 0.01,
        "rotational_clip_y": 0.025,
        "rotational_clip_z": 0.005,
        "rotational_clip_neg_x": 0.01,
        "rotational_clip_neg_y": 0.025,
        "rotational_clip_neg_z": 0.005,
        "rotational_Ki": 0,
    }

    # Same translational gains as COMPLIANCE_PARAM, higher rotational gains,
    # and much wider, uniform clips (0.1 translational, 0.5 rotational).
    PRECISION_PARAM = {
        "translational_stiffness": 2000,
        "translational_damping": 89,
        "rotational_stiffness": 250,
        "rotational_damping": 9,
        "translational_Ki": 0.0,
        "translational_clip_x": 0.1,
        "translational_clip_y": 0.1,
        "translational_clip_z": 0.1,
        "translational_clip_neg_x": 0.1,
        "translational_clip_neg_y": 0.1,
        "translational_clip_neg_z": 0.1,
        "rotational_clip_x": 0.5,
        "rotational_clip_y": 0.5,
        "rotational_clip_z": 0.5,
        "rotational_clip_neg_x": 0.5,
        "rotational_clip_neg_y": 0.5,
        "rotational_clip_neg_z": 0.5,
        "rotational_Ki": 0.0,
    }
class TrainConfig(DefaultTrainingConfig):
    """Training settings and environment factory for the RAM insertion task.

    The class attributes override values from ``DefaultTrainingConfig``;
    ``get_environment`` assembles the wrapped environment.
    """

    # Observation keys: camera streams fed to the policy and to the reward
    # classifier, plus the proprioceptive entries passed to SERLObsWrapper.
    image_keys = ["wrist_1", "wrist_2"]
    classifier_keys = ["wrist_1", "wrist_2"]
    proprio_keys = ["tcp_pose", "tcp_vel", "tcp_force", "tcp_torque", "gripper_pose"]

    # Training-loop settings.
    buffer_period = 1000
    checkpoint_period = 5000
    steps_per_update = 50
    encoder_type = "resnet-pretrained"
    setup_mode = "single-arm-fixed-gripper"

    def get_environment(self, fake_env=False, save_video=False, classifier=False):
        """Build the RAM insertion environment with its wrapper stack.

        Args:
            fake_env: Forwarded to ``RAMEnv``. When true, the
                ``SpacemouseIntervention`` wrapper is not applied.
            save_video: Forwarded to ``RAMEnv``.
            classifier: When truthy, a reward classifier is loaded from
                ``classifier_ckpt/`` (resolved against the current working
                directory) and the environment is additionally wrapped in
                ``MultiCameraBinaryRewardClassifierWrapper``.

        Returns:
            The fully wrapped environment.
        """
        base_env = RAMEnv(
            fake_env=fake_env,
            save_video=save_video,
            config=EnvConfig(),
        )

        # The wrapper order below is significant; each layer wraps the last.
        env = GripperCloseEnv(base_env)
        if not fake_env:
            env = SpacemouseIntervention(env)
        env = RelativeFrame(env)
        env = Quat2EulerWrapper(env)
        env = SERLObsWrapper(env, proprio_keys=self.proprio_keys)
        env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None)

        if not classifier:
            return env

        # The classifier is initialised from a fixed PRNG key and a sample
        # drawn from the wrapped environment's observation space.
        classifier_fn = load_classifier_func(
            key=jax.random.PRNGKey(0),
            sample=env.observation_space.sample(),
            image_keys=self.classifier_keys,
            checkpoint_path=os.path.abspath("classifier_ckpt/"),
        )

        def reward_func(obs):
            """Return 1 when the classifier and the state check both pass, else 0."""
            # Logistic squash of the raw classifier output.
            success_prob = 1 / (1 + jnp.exp(-classifier_fn(obs)))
            # Per the original author's note, the state[0, 6] test is a z
            # position check added to make the classifier more robust; the
            # reward should also work without it.
            # NOTE(review): confirm index 6 is z in the flattened state.
            return int(success_prob > 0.85 and obs["state"][0, 6] > 0.04)

        return MultiCameraBinaryRewardClassifierWrapper(env, reward_func)