Merge pull request #89 from salesforce/classic_control

Classic control

Showing 13 changed files with 423 additions and 14 deletions.
Empty file.
example_envs/single_agent/base.py (40 additions)
import numpy as np


class SingleAgentEnv:

    def __init__(self, episode_length=500, env_backend="cpu", reset_pool_size=0, seed=None):
        """
        :param episode_length: the number of timesteps in an episode
        :param env_backend: "cpu" or "numba" ("pycuda" is not supported for SingleAgentEnv)
        :param reset_pool_size: if reset_pool_size < 2, a single fixed reset state is shared by all envs
        :param seed: seed for the environment's random number generator
        """
        self.num_agents = 1

        self.agents = {}
        for agent_id in range(self.num_agents):
            self.agents[agent_id] = True

        assert episode_length > 0
        self.episode_length = episode_length

        self.action_space = None
        self.observation_space = None
        self.timestep = None

        self.env_backend = env_backend
        self.reset_pool_size = reset_pool_size

        # Seeding
        self.seed = seed


def map_to_single_agent(val):
    # Wrap a value in the {agent_id: value} dict format used throughout WarpDrive.
    return {0: val}


def get_action_for_single_agent(action):
    # Unwrap the action of the single agent (id 0) from the actions dict.
    assert isinstance(action, dict)
    assert len(action) == 1
    return action[0]
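
As a quick illustration of the single-agent dict convention these helpers implement (a hypothetical snippet, not part of the diff), every per-agent quantity is keyed by agent id 0:

obs = map_to_single_agent([0.1, 0.2])       # -> {0: [0.1, 0.2]}
act = get_action_for_single_agent({0: 1})   # -> 1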
Empty file.
Empty file.
example_envs/single_agent/classic_control/cartpole/cartpole.py (129 additions)
import numpy as np
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed
from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext

from example_envs.single_agent.base import SingleAgentEnv, map_to_single_agent, get_action_for_single_agent
from gym.envs.classic_control.cartpole import CartPoleEnv

_OBSERVATIONS = Constants.OBSERVATIONS
_ACTIONS = Constants.ACTIONS
_REWARDS = Constants.REWARDS


class ClassicControlCartPoleEnv(SingleAgentEnv):

    name = "ClassicControlCartPoleEnv"

    def __init__(self, episode_length, env_backend="cpu", reset_pool_size=0, seed=None):
        super().__init__(episode_length, env_backend, reset_pool_size, seed=seed)

        self.gym_env = CartPoleEnv()

        self.action_space = map_to_single_agent(self.gym_env.action_space)
        self.observation_space = map_to_single_agent(self.gym_env.observation_space)

    def step(self, action=None):
        self.timestep += 1
        action = get_action_for_single_agent(action)
        state, reward, terminated, _, _ = self.gym_env.step(action)

        obs = map_to_single_agent(state)
        rew = map_to_single_agent(reward)
        done = {"__all__": self.timestep >= self.episode_length or terminated}
        info = {}

        return obs, rew, done, info

    def reset(self):
        self.timestep = 0
        if self.reset_pool_size < 2:
            # Always reset to the same fixed initial state.
            initial_state, _ = self.gym_env.reset(seed=self.seed)
        else:
            initial_state, _ = self.gym_env.reset(seed=None)
        obs = map_to_single_agent(initial_state)

        return obs


class CUDAClassicControlCartPoleEnv(ClassicControlCartPoleEnv, CUDAEnvironmentContext):

    def get_data_dictionary(self):
        data_dict = DataFeed()
        initial_state, _ = self.gym_env.reset(seed=self.seed)

        if self.reset_pool_size < 2:
            data_dict.add_data(
                name="state",
                data=np.atleast_2d(initial_state),
                save_copy_and_apply_at_reset=True,
            )
        else:
            data_dict.add_data(
                name="state",
                data=np.atleast_2d(initial_state),
                save_copy_and_apply_at_reset=False,
            )

        data_dict.add_data_list(
            [
                ("gravity", self.gym_env.gravity),
                ("masspole", self.gym_env.masspole),
                ("total_mass", self.gym_env.masspole + self.gym_env.masscart),
                ("length", self.gym_env.length),
                ("polemass_length", self.gym_env.masspole * self.gym_env.length),
                ("force_mag", self.gym_env.force_mag),
                ("tau", self.gym_env.tau),
                ("theta_threshold_radians", self.gym_env.theta_threshold_radians),
                ("x_threshold", self.gym_env.x_threshold),
            ]
        )
        return data_dict

    def get_tensor_dictionary(self):
        tensor_dict = DataFeed()
        return tensor_dict

    def get_reset_pool_dictionary(self):
        reset_pool_dict = DataFeed()
        if self.reset_pool_size >= 2:
            state_reset_pool = []
            for _ in range(self.reset_pool_size):
                initial_state, _ = self.gym_env.reset(seed=None)
                state_reset_pool.append(np.atleast_2d(initial_state))
            state_reset_pool = np.stack(state_reset_pool, axis=0)
            assert len(state_reset_pool.shape) == 3 and state_reset_pool.shape[2] == 4

            reset_pool_dict.add_pool_for_reset(
                name="state_reset_pool",
                data=state_reset_pool,
                reset_target="state",
            )
        return reset_pool_dict

    def step(self, actions=None):
        self.timestep += 1
        args = [
            "state",
            _ACTIONS,
            "_done_",
            _REWARDS,
            _OBSERVATIONS,
            "gravity",
            "masspole",
            "total_mass",
            "length",
            "polemass_length",
            "force_mag",
            "tau",
            "theta_threshold_radians",
            "x_threshold",
            "_timestep_",
            ("episode_length", "meta"),
        ]
        if self.env_backend == "numba":
            self.cuda_step[
                self.cuda_function_manager.grid, self.cuda_function_manager.block
            ](*self.cuda_step_function_feed(args))
        else:
            raise Exception("CUDAClassicControlCartPoleEnv expects env_backend = 'numba'")
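
For orientation, here is a minimal, hypothetical driver for the CPU environment above (not part of the diff; it assumes only the classes defined in this PR and gym's action-space sample() API):

from example_envs.single_agent.classic_control.cartpole.cartpole import ClassicControlCartPoleEnv

env = ClassicControlCartPoleEnv(episode_length=500, env_backend="cpu", seed=0)
obs = env.reset()
done = {"__all__": False}
while not done["__all__"]:
    # Actions are wrapped in the {agent_id: action} dict format expected by step().
    action = {0: env.action_space[0].sample()}
    obs, rew, done, info = env.step(action)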
example_envs/single_agent/classic_control/cartpole/cartpole_step_numba.py (83 additions)
import math

import numba.cuda as numba_driver


@numba_driver.jit
def NumbaClassicControlCartPoleEnvStep(
    state_arr,
    action_arr,
    done_arr,
    reward_arr,
    observation_arr,
    gravity,
    masspole,
    total_mass,
    length,
    polemass_length,
    force_mag,
    tau,
    theta_threshold_radians,
    x_threshold,
    env_timestep_arr,
    episode_length,
):

    kEnvId = numba_driver.blockIdx.x
    kThisAgentId = numba_driver.threadIdx.x

    assert kThisAgentId == 0, "We only have one agent per environment"

    env_timestep_arr[kEnvId] += 1

    assert 0 < env_timestep_arr[kEnvId] <= episode_length

    reward_arr[kEnvId, kThisAgentId] = 0.0

    action = action_arr[kEnvId, kThisAgentId, 0]

    x = state_arr[kEnvId, kThisAgentId, 0]
    x_dot = state_arr[kEnvId, kThisAgentId, 1]
    theta = state_arr[kEnvId, kThisAgentId, 2]
    theta_dot = state_arr[kEnvId, kThisAgentId, 3]

    if action > 0.5:
        force = force_mag
    else:
        force = -force_mag

    costheta = math.cos(theta)
    sintheta = math.sin(theta)

    temp = (force + polemass_length * theta_dot ** 2 * sintheta) / total_mass
    thetaacc = (gravity * sintheta - costheta * temp) / (
        length * (4.0 / 3.0 - masspole * costheta ** 2 / total_mass)
    )
    xacc = temp - polemass_length * thetaacc * costheta / total_mass

    # We use kinematics_integrator == "euler", as in the original gym code.
    x = x + tau * x_dot
    x_dot = x_dot + tau * xacc
    theta = theta + tau * theta_dot
    theta_dot = theta_dot + tau * thetaacc

    state_arr[kEnvId, kThisAgentId, 0] = x
    state_arr[kEnvId, kThisAgentId, 1] = x_dot
    state_arr[kEnvId, kThisAgentId, 2] = theta
    state_arr[kEnvId, kThisAgentId, 3] = theta_dot

    observation_arr[kEnvId, kThisAgentId, 0] = state_arr[kEnvId, kThisAgentId, 0]
    observation_arr[kEnvId, kThisAgentId, 1] = state_arr[kEnvId, kThisAgentId, 1]
    observation_arr[kEnvId, kThisAgentId, 2] = state_arr[kEnvId, kThisAgentId, 2]
    observation_arr[kEnvId, kThisAgentId, 3] = state_arr[kEnvId, kThisAgentId, 3]

    terminated = bool(
        x < -x_threshold
        or x > x_threshold
        or theta < -theta_threshold_radians
        or theta > theta_threshold_radians
    )

    # Assign a reward of 1 at every step, consistent with the original CartPole logic.
    reward_arr[kEnvId, kThisAgentId] = 1.0

    if env_timestep_arr[kEnvId] == episode_length or terminated:
        done_arr[kEnvId] = 1
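
For context, a hypothetical launch sketch (not part of the diff; in WarpDrive the function manager normally handles this, as in CUDAClassicControlCartPoleEnv.step above). It launches the kernel with one block per environment and one thread per agent, matching the blockIdx.x / threadIdx.x indexing inside the kernel; the scalar values mirror gym's default CartPole constants:

import math
import numpy as np
import numba.cuda as numba_driver

num_envs, num_agents = 3, 1
state = numba_driver.to_device(np.zeros((num_envs, num_agents, 4), dtype=np.float32))
actions = numba_driver.to_device(np.ones((num_envs, num_agents, 1), dtype=np.int32))
done = numba_driver.to_device(np.zeros(num_envs, dtype=np.int32))
rewards = numba_driver.to_device(np.zeros((num_envs, num_agents), dtype=np.float32))
observations = numba_driver.to_device(np.zeros((num_envs, num_agents, 4), dtype=np.float32))
timesteps = numba_driver.to_device(np.zeros(num_envs, dtype=np.int32))

# Launch config [grid, block] = [num_envs blocks, num_agents threads per block].
NumbaClassicControlCartPoleEnvStep[num_envs, num_agents](
    state, actions, done, rewards, observations,
    9.8, 0.1, 1.1, 0.5, 0.05,                 # gravity, masspole, total_mass, length, polemass_length
    10.0, 0.02, 12 * 2 * math.pi / 360, 2.4,  # force_mag, tau, theta_threshold_radians, x_threshold
    timesteps, 500,                           # env_timestep_arr, episode_length
)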
tests/example_envs/numba_tests/single_agent/classic_control/test_cartpole.py (89 additions)
import unittest
import numpy as np
import torch

from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU
from example_envs.single_agent.classic_control.cartpole.cartpole import \
    ClassicControlCartPoleEnv, CUDAClassicControlCartPoleEnv
from warp_drive.env_wrapper import EnvWrapper


env_configs = {
    "test1": {
        "episode_length": 500,
        "reset_pool_size": 0,
        "seed": 32145,
    },
    "test2": {
        "episode_length": 200,
        "reset_pool_size": 0,
        "seed": 54231,
    },
}


class MyTestCase(unittest.TestCase):
    """
    CPU vs. GPU consistency unit tests
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.testing_class = EnvironmentCPUvsGPU(
            cpu_env_class=ClassicControlCartPoleEnv,
            cuda_env_class=CUDAClassicControlCartPoleEnv,
            env_configs=env_configs,
            gpu_env_backend="numba",
            num_envs=5,
            num_episodes=2,
        )

    def test_env_consistency(self):
        try:
            self.testing_class.test_env_reset_and_step()
        except AssertionError:
            self.fail("ClassicControlCartPoleEnv environment consistency tests failed")

    def test_reset_pool(self):
        env_wrapper = EnvWrapper(
            env_obj=CUDAClassicControlCartPoleEnv(episode_length=100, reset_pool_size=3),
            num_envs=3,
            env_backend="numba",
        )
        env_wrapper.reset_all_envs()
        env_wrapper.env_resetter.init_reset_pool(env_wrapper.cuda_data_manager, seed=12345)
        self.assertTrue(env_wrapper.cuda_data_manager.reset_target_to_pool["state"] == "state_reset_pool")

        # squeeze() the agent dimension, which is always 1
        state_after_initial_reset = env_wrapper.cuda_data_manager.pull_data_from_device("state").squeeze()

        reset_pool = env_wrapper.cuda_data_manager.pull_data_from_device(
            env_wrapper.cuda_data_manager.get_reset_pool("state"))
        reset_pool_mean = reset_pool.mean(axis=0).squeeze()

        # Mark envs 0 and 1 as done so they are re-drawn from the reset pool;
        # env 2 stays not-done and should keep its current state.
        env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy(
            np.array([1, 1, 0])
        ).cuda()

        state_values = {0: [], 1: [], 2: []}
        for _ in range(10000):
            env_wrapper.env_resetter.reset_when_done(
                env_wrapper.cuda_data_manager, mode="if_done", undo_done_after_reset=False
            )
            res = env_wrapper.cuda_data_manager.pull_data_from_device("state")
            state_values[0].append(res[0])
            state_values[1].append(res[1])
            state_values[2].append(res[2])

        state_values_env0_mean = np.stack(state_values[0]).mean(axis=0).squeeze()
        state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze()
        state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze()

        # Envs 0 and 1 were repeatedly reset from the pool, so their mean states should
        # approach the pool mean; env 2 should stay at its initial reset state.
        for i in range(len(reset_pool_mean)):
            self.assertTrue(np.absolute(state_values_env0_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i]))
            self.assertTrue(np.absolute(state_values_env1_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i]))
            self.assertTrue(
                np.absolute(
                    state_values_env2_mean[i] - state_after_initial_reset[0][i]
                ) < 0.001 * abs(state_after_initial_reset[0][i])
            )
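
The visible diff does not show a test runner; a standard unittest entry point (an assumption, in case the original file defines its own) would be:

if __name__ == "__main__":
    unittest.main()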