
Commit f781beb

change it for the minitaur env

jietan committed Mar 27, 2018
1 parent 36b5e93 commit f781beb
Showing 4 changed files with 22 additions and 18 deletions.
16 changes: 8 additions & 8 deletions code/ars.py
@@ -4,7 +4,6 @@
Aurelia Guy
Benjamin Recht
'''
-
import parser
import time
import os
@@ -17,6 +16,8 @@
from policies import *
import socket
from shared_noise import *
+ import pybullet
+ from pybullet_envs.bullet import minitaur_gym_env

@ray.remote
class Worker(object):
@@ -32,9 +33,9 @@ def __init__(self, env_seed,
delta_std=0.02):

# initialize OpenAI environment for each worker
- self.env = gym.make(env_name)
+ self.env = minitaur_gym_env.MinitaurBulletEnv()#gym.make(env_name)
self.env.seed(env_seed)
-
+ print ("env seed: {}".format(env_seed))
# each worker gets access to the shared noise table
# with independent random streams for sampling
# from the shared noise table.
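
For orientation: a minimal sketch of what the swapped-in constructor gives each worker, assuming pybullet and pybullet_envs are installed. MinitaurBulletEnv exposes the usual gym API (seed/reset/step), which is why the surrounding ARS code needs no other changes; the shapes in the comment are what this env typically reports, not something the diff asserts.

import numpy as np
from pybullet_envs.bullet import minitaur_gym_env

env = minitaur_gym_env.MinitaurBulletEnv()   # direct construction replaces gym.make
env.seed(0)                                  # gym-style seeding still works
obs = env.reset()
print(env.observation_space.shape, env.action_space.shape)  # typically (28,), (8,)
action = np.zeros(env.action_space.shape[0])                # 8 motor commands in [-1, 1]
obs, reward, done, info = env.step(action)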
@@ -99,7 +100,7 @@ def do_rollouts(self, w_policy, num_rollouts = 1, shift = 1, evaluate = False):

# for evaluation we do not shift the rewards (shift = 0) and we use the
# default rollout length (1000 for the MuJoCo locomotion tasks)
- reward, r_steps = self.rollout(shift = 0., rollout_length = self.env.spec.timestep_limit)
+ reward, r_steps = self.rollout(shift = 0., rollout_length = 1000)
rollout_rewards.append(reward)

else:
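
The hardcoded 1000 replaces env.spec.timestep_limit because an env constructed directly, rather than through gym.make, carries no registered spec, so the old attribute lookup would fail. A more defensive variant, offered only as a sketch and not part of this commit:

# Hypothetical fallback: keep gym's limit when a spec exists, else use 1000.
spec = getattr(self.env, 'spec', None)
limit = spec.timestep_limit if spec is not None else 1000
reward, r_steps = self.rollout(shift=0., rollout_length=limit)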
@@ -160,7 +161,7 @@ def __init__(self, env_name='HalfCheetah-v1',
logz.configure_output_dir(logdir)
logz.save_params(params)

- env = gym.make(env_name)
+ env = minitaur_gym_env.MinitaurBulletEnv() #gym.make(env_name)

self.timesteps = 0
self.action_size = env.action_space.shape[0]
@@ -180,6 +181,7 @@
# create shared table for storing noise
print("Creating deltas table.")
deltas_id = create_shared_noise.remote()
+
self.deltas = SharedNoiseTable(ray.get(deltas_id), seed = seed + 3)
print('Created deltas table.')

@@ -193,7 +195,6 @@
rollout_length=rollout_length,
delta_std=delta_std) for i in range(num_workers)]

-
# initialize policy
if policy_params['type'] == 'linear':
self.policy = LinearPolicy(policy_params)
@@ -209,7 +210,6 @@ def aggregate_rollouts(self, num_rollouts = None, evaluate = False):
"""
Aggregate update step from rollouts generated in parallel.
"""
-
if num_rollouts is None:
num_deltas = self.num_deltas
else:
@@ -355,7 +355,7 @@ def run_ars(params):
if not(os.path.exists(logdir)):
os.makedirs(logdir)

- env = gym.make(params['env_name'])
+ env = minitaur_gym_env.MinitaurBulletEnv() #gym.make(params['env_name'])
ob_dim = env.observation_space.shape[0]
ac_dim = env.action_space.shape[0]

3 changes: 2 additions & 1 deletion code/policies.py
@@ -48,7 +48,8 @@ def __init__(self, policy_params):

def act(self, ob):
ob = self.observation_filter(ob, update=self.update_filter)
- return np.dot(self.weights, ob)
+
+ return np.clip(np.dot(self.weights, ob), -1.0, 1.0)

def get_weights_plus_stats(self):

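Taken together with the observation filter above, the changed act() is a linear map on the filtered observation with outputs clipped to the [-1, 1] actuator range the Minitaur env expects. A standalone sketch (the class and constructor are assumed names, not from this repo):

import numpy as np

class ClippedLinearPolicy:
    # Sketch of the modified act(): linear in the observation, then clipped.
    def __init__(self, ob_dim, ac_dim):
        self.weights = np.zeros((ac_dim, ob_dim))
    def act(self, ob):
        # Clipping keeps actions inside the bounds the Minitaur motors accept.
        return np.clip(np.dot(self.weights, ob), -1.0, 1.0)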
21 changes: 12 additions & 9 deletions code/run_policy.py
@@ -7,14 +7,15 @@
"""
import numpy as np
import gym
+ import pybullet
+ from pybullet_envs.bullet import minitaur_gym_env

def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('expert_policy_file', type=str)
parser.add_argument('envname', type=str)
parser.add_argument('--render', action='store_true')
- parser.add_argument('--num_rollouts', type=int, default=20,
+ parser.add_argument('--num_rollouts', type=int, default=1,
help='Number of expert rollouts')
args = parser.parse_args()

@@ -26,32 +27,34 @@ def main():
# mean and std of state vectors estimated online by ARS.
mean = lin_policy[1]
std = lin_policy[2]

- env = gym.make(args.envname)
+ env = minitaur_gym_env.MinitaurBulletEnv(render=True)#gym.make(env_name)
+ # env = gym.make(args.envname)

returns = []
observations = []
actions = []
for i in range(args.num_rollouts):

print('iter', i)
obs = env.reset()
+ log_id = pybullet.startStateLogging(pybullet.STATE_LOGGING_VIDEO_MP4, "/usr/local/google/home/jietan/Projects/ARS/data/minitaur{}.mp4".format(i))
done = False
totalr = 0.
steps = 0
while not done:
- action = np.dot(M, (obs - mean)/std)
+ action = np.clip(np.dot(M, (obs - mean)/std), -1.0, 1.0)
observations.append(obs)
actions.append(action)



obs, r, done, _ = env.step(action)
totalr += r
steps += 1
if args.render:
env.render()
- if steps % 100 == 0: print("%i/%i"%(steps, env.spec.timestep_limit))
- if steps >= env.spec.timestep_limit:
+ if steps % 100 == 0: print("%i/%i"%(steps, 1000))
+ if steps >= 1000:
break
+ pybullet.stopStateLogging(log_id)
returns.append(totalr)

print('returns', returns)
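The startStateLogging/stopStateLogging pair added above brackets each rollout so pybullet records one MP4 per episode. The pattern in isolation, with a relative output path substituted for the author's absolute one; video capture generally requires the env to run with rendering enabled, which matches the render=True passed to the constructor:

import pybullet

log_id = pybullet.startStateLogging(
    pybullet.STATE_LOGGING_VIDEO_MP4, "minitaur_rollout_0.mp4")
# ... step the environment in a loop; frames are captured while logging is on ...
pybullet.stopStateLogging(log_id)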
Binary file added trained_policies/Minitaur/lin_policy_plus.npz
