-
Notifications
You must be signed in to change notification settings - Fork 381
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(xrk): add new env named Flozen Lake and DQN algorithm. (#781)
* first_commit * environment test pass * frame creative * change init file to new function * change code to fit the pr request * To ensure the environment operates correctly, consider adding more assertions for robust validation * now it can use serial_pipeline to make function * we change the code to justify new turnel * Compliance Check * add gif * format my code
- Loading branch information
1 parent
c999b07
commit aeb4c9c
Showing
9 changed files
with
502 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from ditk import logging | ||
from ding.model import DQN | ||
from ding.policy import DQNPolicy | ||
from ding.envs import DingEnvWrapper, BaseEnvManagerV2 | ||
from ding.data import DequeBuffer | ||
from ding.config import compile_config | ||
from ding.framework import task | ||
from ding.framework.context import OnlineRLContext | ||
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ | ||
eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver | ||
from ding.utils import set_pkg_seed | ||
from dizoo.frozen_lake.config.frozen_lake_dqn_config import main_config, create_config | ||
from dizoo.frozen_lake.envs import FrozenLakeEnv | ||
|
||
|
||
def main(): | ||
logging.getLogger().setLevel(logging.INFO) | ||
main_config.policy.nstep = 5 | ||
cfg = compile_config(main_config, create_cfg=create_config, auto=True) | ||
with task.start(async_mode=False, ctx=OnlineRLContext()): | ||
collector_env = BaseEnvManagerV2( | ||
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager | ||
) | ||
evaluator_env = BaseEnvManagerV2( | ||
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager | ||
) | ||
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) | ||
|
||
model = DQN(**cfg.policy.model) | ||
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) | ||
policy = DQNPolicy(cfg.policy, model=model) | ||
|
||
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) | ||
task.use(eps_greedy_handler(cfg)) | ||
task.use(StepCollector(cfg, policy.collect_mode, collector_env)) | ||
task.use(nstep_reward_enhancer(cfg)) | ||
task.use(data_pusher(cfg, buffer_)) | ||
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) | ||
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) | ||
task.use(final_ctx_saver(cfg.exp_name)) | ||
task.run() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .frozen_lake_dqn_config import main_config, create_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from easydict import EasyDict | ||
|
||
frozen_lake_dqn_config = dict( | ||
exp_name='frozen_lake_seed0', | ||
env=dict( | ||
collector_env_num=8, | ||
evaluator_env_num=5, | ||
n_evaluator_episode=10, | ||
env_id='FrozenLake-v1', | ||
desc=None, | ||
map_name="4x4", | ||
is_slippery=False, | ||
save_replay_gif=False, | ||
), | ||
policy=dict( | ||
cuda=True, | ||
load_path='frozen_lake_seed0/ckpt/ckpt_best.pth.tar', | ||
model=dict( | ||
obs_shape=16, | ||
action_shape=4, | ||
encoder_hidden_size_list=[128, 128, 64], | ||
dueling=True, | ||
), | ||
nstep=3, | ||
discount_factor=0.97, | ||
learn=dict( | ||
update_per_collect=5, | ||
batch_size=256, | ||
learning_rate=0.001, | ||
), | ||
collect=dict(n_sample=10), | ||
eval=dict(evaluator=dict(eval_freq=40, )), | ||
other=dict( | ||
eps=dict( | ||
type='exp', | ||
start=0.8, | ||
end=0.1, | ||
decay=10000, | ||
), | ||
replay_buffer=dict(replay_buffer_size=20000, ), | ||
), | ||
), | ||
) | ||
|
||
frozen_lake_dqn_config = EasyDict(frozen_lake_dqn_config) | ||
main_config = frozen_lake_dqn_config | ||
|
||
frozen_lake_dqn_create_config = dict( | ||
env=dict( | ||
type='frozen_lake', | ||
import_names=['dizoo.frozen_lake.envs.frozen_lake_env'], | ||
), | ||
env_manager=dict(type='base'), | ||
policy=dict(type='dqn'), | ||
replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']), | ||
) | ||
|
||
frozen_lake_dqn_create_config = EasyDict(frozen_lake_dqn_create_config) | ||
create_config = frozen_lake_dqn_create_config | ||
|
||
if __name__ == "__main__": | ||
# or you can enter `ding -m serial -c frozen_lake_dqn_config.py -s 0` | ||
from ding.entry import serial_pipeline | ||
serial_pipeline((main_config, create_config), max_env_step=5000, seed=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .frozen_lake_env import FrozenLakeEnv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from typing import Any, Dict, List, Optional | ||
import imageio | ||
import os | ||
import gymnasium as gymn | ||
import numpy as np | ||
from ding.envs import BaseEnv, BaseEnvTimestep | ||
from ding.torch_utils import to_ndarray | ||
from ding.utils import ENV_REGISTRY | ||
|
||
|
||
@ENV_REGISTRY.register('frozen_lake') | ||
class FrozenLakeEnv(BaseEnv): | ||
|
||
def __init__(self, cfg) -> None: | ||
self._cfg = cfg | ||
assert self._cfg.env_id == "FrozenLake-v1", "yout name is not FrozernLake_v1" | ||
self._init_flag = False | ||
self._save_replay_bool = False | ||
self._save_replay_count = 0 | ||
self._init_flag = False | ||
self._frames = [] | ||
self._replay_path = False | ||
|
||
def reset(self) -> np.ndarray: | ||
if not self._init_flag: | ||
if not self._cfg.desc: #specify maps non-preloaded maps | ||
self._env = gymn.make( | ||
self._cfg.env_id, | ||
desc=self._cfg.desc, | ||
map_name=self._cfg.map_name, | ||
is_slippery=self._cfg.is_slippery, | ||
render_mode="rgb_array" | ||
) | ||
self._observation_space = self._env.observation_space | ||
self._action_space = self._env.action_space | ||
self._reward_space = gymn.spaces.Box( | ||
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 | ||
) | ||
self._init_flag = True | ||
self._eval_episode_return = 0 | ||
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: | ||
np_seed = 100 * np.random.randint(1, 1000) | ||
self._env_seed = self._seed + np_seed | ||
elif hasattr(self, '_seed'): | ||
self._env_seed = self._seed | ||
if hasattr(self, '_seed'): | ||
obs, info = self._env.reset(seed=self._env_seed) | ||
else: | ||
obs, info = self._env.reset() | ||
obs = np.eye(16, dtype=np.float32)[obs - 1] | ||
return obs | ||
|
||
def close(self) -> None: | ||
if self._init_flag: | ||
self._env.close() | ||
self._init_flag = False | ||
|
||
def seed(self, seed: int, dynamic_seed: bool = True) -> None: | ||
self._seed = seed | ||
self._dynamic_seed = dynamic_seed | ||
np.random.seed(self._seed) | ||
|
||
def step(self, action: Dict) -> BaseEnvTimestep: | ||
obs, rew, terminated, truncated, info = self._env.step(action[0]) | ||
self._eval_episode_return += rew | ||
obs = np.eye(16, dtype=np.float32)[obs - 1] | ||
rew = to_ndarray([rew]) | ||
if self._save_replay_bool: | ||
picture = self._env.render() | ||
self._frames.append(picture) | ||
if terminated or truncated: | ||
done = True | ||
else: | ||
done = False | ||
if done: | ||
info['eval_episode_return'] = self._eval_episode_return | ||
if self._save_replay_bool: | ||
assert self._replay_path is not None, "your should have a path" | ||
path = os.path.join( | ||
self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count) | ||
) | ||
self.frames_to_gif(self._frames, path) | ||
self._frames = [] | ||
self._save_replay_count += 1 | ||
rew = rew.astype(np.float32) | ||
return BaseEnvTimestep(obs, rew, done, info) | ||
|
||
def random_action(self) -> Dict: | ||
raw_action = self._env.action_space.sample() | ||
my_type = type(self._env.action_space) | ||
return [raw_action] | ||
|
||
def __repr__(self) -> str: | ||
return "DI-engine Frozen Lake Env" | ||
|
||
@property | ||
def observation_space(self) -> gymn.spaces.Space: | ||
return self._observation_space | ||
|
||
@property | ||
def action_space(self) -> gymn.spaces.Space: | ||
return self._action_space | ||
|
||
@property | ||
def reward_space(self) -> gymn.spaces.Space: | ||
return self._reward_space | ||
|
||
def enable_save_replay(self, replay_path: Optional[str] = None) -> None: | ||
if replay_path is None: | ||
replay_path = './video' | ||
self._replay_path = replay_path | ||
self._save_replay_bool = True | ||
self._save_replay_count = 0 | ||
self._frames = [] | ||
|
||
@staticmethod | ||
def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration: float = 0.1) -> None: | ||
""" | ||
Convert a list of frames into a GIF. | ||
Args: | ||
- frames (List[imageio.core.util.Array]): A list of frames, each frame is an image. | ||
- gif_path (str): The path to save the GIF file. | ||
- duration (float): Duration between each frame in the GIF (seconds). | ||
Returns: | ||
None, the GIF file is saved directly to the specified path. | ||
""" | ||
# Save all frames as temporary image files | ||
temp_image_files = [] | ||
for i, frame in enumerate(frames): | ||
temp_image_file = f"frame_{i}.png" # Temporary file name | ||
imageio.imwrite(temp_image_file, frame) # Save the frame as a PNG file | ||
temp_image_files.append(temp_image_file) | ||
|
||
# Use imageio to convert temporary image files to GIF | ||
with imageio.get_writer(gif_path, mode='I', duration=duration) as writer: | ||
for temp_image_file in temp_image_files: | ||
image = imageio.imread(temp_image_file) | ||
writer.append_data(image) | ||
|
||
# Clean up temporary image files | ||
for temp_image_file in temp_image_files: | ||
os.remove(temp_image_file) | ||
print(f"GIF saved as {gif_path}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import numpy as np | ||
import pytest | ||
from dizoo.frozen_lake.envs import FrozenLakeEnv | ||
from easydict import EasyDict | ||
|
||
|
||
@pytest.mark.envtest | ||
class TestGymHybridEnv: | ||
|
||
def test_my_lake(self): | ||
env = FrozenLakeEnv( | ||
EasyDict({ | ||
'env_id': 'FrozenLake-v1', | ||
'desc': None, | ||
'map_name': "4x4", | ||
'is_slippery': False, | ||
}) | ||
) | ||
for _ in range(5): | ||
env.seed(314, dynamic_seed=False) | ||
assert env._seed == 314 | ||
obs = env.reset() | ||
assert obs.shape == ( | ||
16, | ||
), "Considering the one-hot encoding format, your observation should have a dimensionality of 16." | ||
for i in range(10): | ||
env.enable_save_replay("./video") | ||
# Both ``env.random_action()``, and utilizing ``np.random`` as well as action space, | ||
# can generate legal random action. | ||
if i < 5: | ||
random_action = np.array([env.action_space.sample()]) | ||
else: | ||
random_action = env.random_action() | ||
timestep = env.step(random_action) | ||
print(timestep) | ||
assert isinstance(timestep.obs, np.ndarray) | ||
assert isinstance(timestep.done, bool) | ||
assert timestep.obs.shape == (16, ) | ||
assert timestep.reward.shape == (1, ) | ||
assert timestep.reward >= env.reward_space.low | ||
assert timestep.reward <= env.reward_space.high | ||
|
||
print(env.observation_space, env.action_space, env.reward_space) | ||
env.close() |