Skip to content

Commit

Permalink
FEATURE: reset now accepts an optional seed. A few env suites upgrade…
Browse files Browse the repository at this point in the history
…d. Others are WIP and will come with the next few commits
  • Loading branch information
vikashplus committed Dec 31, 2023
1 parent ae51c9e commit 676375f
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 16 deletions.
6 changes: 3 additions & 3 deletions robohive/envs/arms/pick_place_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
================================================= """

import collections
from robohive.utils.import_utils import gym
from robohive.utils import gym
import numpy as np

from robohive.envs import env_base
Expand Down Expand Up @@ -115,7 +115,7 @@ def get_reward_dict(self, obs_dict):
rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0)
return rwd_dict

def reset(self):
def reset(self, **kwargs):

if self.randomize:
# target location
Expand All @@ -137,7 +137,7 @@ def reset(self):
self.sim.model.geom_rgba[gid]=self.np_random.uniform(low=[.2, .2, .2, 1], high=[.9, .9, .9, 1]) # random color
self.sim.forward()

obs = super().reset(self.init_qpos, self.init_qvel)
obs = super().reset(self.init_qpos, self.init_qvel, **kwargs)
return obs

# def viewer_setup(self):
Expand Down
6 changes: 3 additions & 3 deletions robohive/envs/arms/push_base_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
================================================= """

import collections
from robohive.utils.import_utils import gym
from robohive.utils import gym
import numpy as np

from robohive.envs import env_base
Expand Down Expand Up @@ -103,8 +103,8 @@ def get_reward_dict(self, obs_dict):
rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0)
return rwd_dict

def reset(self):
def reset(self, **kwargs):
self.sim.model.site_pos[self.target_sid] = self.np_random.uniform(high=self.target_xyz_range['high'], low=self.target_xyz_range['low'])
self.sim_obsd.model.site_pos[self.target_sid] = self.sim.model.site_pos[self.target_sid]
obs = super().reset(self.init_qpos, self.init_qvel)
obs = super().reset(self.init_qpos, self.init_qvel, **kwargs)
return obs
6 changes: 3 additions & 3 deletions robohive/envs/arms/reach_base_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
================================================= """

import collections
from robohive.utils.import_utils import gym
from robohive.utils import gym
import numpy as np

from robohive.envs import env_base
Expand Down Expand Up @@ -97,8 +97,8 @@ def get_reward_dict(self, obs_dict):
rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0)
return rwd_dict

def reset(self, reset_qpos=None, reset_qvel=None):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
self.sim.model.site_pos[self.target_sid] = self.np_random.uniform(high=self.target_xyz_range['high'], low=self.target_xyz_range['low'])
self.sim_obsd.model.site_pos[self.target_sid] = self.sim.model.site_pos[self.target_sid]
obs = super().reset(reset_qpos, reset_qvel)
obs = super().reset(reset_qpos, reset_qvel, **kwargs)
return obs
10 changes: 5 additions & 5 deletions robohive/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================= """

from robohive.utils.import_utils import gym
from robohive.utils import gym
import numpy as np
import os
import time as timer
Expand Down Expand Up @@ -492,14 +492,14 @@ def get_input_seed(self):
return self.input_seed


def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
def _reset(self, reset_qpos=None, reset_qvel=None, seed=None, **kwargs):
"""
Reset the environment
Default implementation provided. Override if env needs custom reset
"""
qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos
qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel
self.robot.reset(qpos, qvel, **kwargs)
self.robot.reset(reset_pos=qpos, reset_vel=qvel, seed=seed, **kwargs)
return self.get_obs()
@implement_for("gym", None, "0.26")
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
Expand All @@ -508,8 +508,8 @@ def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
@implement_for("gymnasium")
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
def reset(self, reset_qpos=None, reset_qvel=None, seed=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, seed=seed, **kwargs), {}

# @property
# def _step(self, a):
Expand Down
5 changes: 3 additions & 2 deletions robohive/robot/robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,8 @@ def step(self, ctrl_desired, step_duration, ctrl_normalized=True, realTimeSim=Fa
def reset(self,
reset_pos,
reset_vel,
blocking = True
blocking = True,
**kwargs
):

prompt("Resetting {}".format(self.name), 'white', 'on_grey', flush=True)
Expand Down Expand Up @@ -793,7 +794,7 @@ def __del__(self):


def demo_robot():
from robohive.utils.import_utils import gym
from robohive.utils import gym

prompt("Starting Robot===================")
env = gym.make('FrankaReachFixed-v0')
Expand Down

0 comments on commit 676375f

Please sign in to comment.