Add circle detection as a penalty for the agent and add MPS support to the MLP #36

Open · wants to merge 2 commits into master
67 changes: 67 additions & 0 deletions main/circle_detection.py
@@ -0,0 +1,67 @@
def is_surrounded(board, flag, row, col, circle):
    """
    Check whether the connected region of 0-cells containing (row, col) is
    completely enclosed by 1-cells.

    Depth-first search over the neighbouring 0-cells: if the search reaches a
    0-cell on the board boundary, the region is not surrounded. Every visited
    cell is marked in `flag` and collected in `circle`.
    """

    def reach_edge(i, j):
        return i == 0 or i == len(board) - 1 or j == 0 or j == len(board[0]) - 1

    surrounded = True
    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    for dr, dc in directions:
        new_row = row + dr
        new_col = col + dc
        if new_row < 0 or new_col < 0 or new_row == len(board) or new_col == len(board[0]):
            continue
        if board[new_row][new_col] == 0 and flag[new_row][new_col] == 0:
            flag[new_row][new_col] = -1
            circle.add((new_row, new_col))
            if reach_edge(new_row, new_col):
                surrounded = False  # the region reaches the boundary, so it is open
            if not is_surrounded(board, flag, new_row, new_col, circle):
                surrounded = False
    return surrounded


def is_circle(board):
    """Return the number of 0-cells that are completely enclosed by 1-cells."""
    flag = [[0] * len(board[0]) for _ in range(len(board))]  # visited markers, kept separate from the board
    circle = set()
    for i in range(1, len(board) - 1):
        for j in range(1, len(board[0]) - 1):
            if board[i][j] == 0 and flag[i][j] == 0:
                # For each unvisited zero, mark it and flood-fill its region
                flag[i][j] = -1
                region = {(i, j)}
                if is_surrounded(board, flag, i, j, region):
                    # Only fully enclosed regions are counted; open regions are discarded
                    circle |= region
    return len(circle)  # number of 0-cells surrounded by 1-cells


if __name__ == "__main__":

    # test
    board = [
        [1, 0, 1, 1, 1],
        [1, 1, 1, 0, 1],
        [1, 1, 1, 0, 1],
        [1, 1, 1, 0, 1],
        [1, 1, 1, 1, 1]
    ]
    print(is_circle(board))
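    # Expected output: 3 (the three 0-cells in column 3, rows 1 to 3, are fully enclosed;
    # the 0 at (0, 1) lies on the boundary and is not counted)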
Binary file not shown.
Binary file not shown.
Binary file not shown.
13 changes: 11 additions & 2 deletions main/snake_game_custom_wrapper_cnn.py
@@ -4,6 +4,7 @@
import numpy as np

from snake_game import SnakeGame
from circle_detection import is_circle

class SnakeEnv(gym.Env):
def __init__(self, seed=0, board_size=12, silent_mode=True, limit_step=True):
@@ -47,6 +48,9 @@ def step(self, action):
self.done, info = self.game.step(action) # info = {"snake_size": int, "snake_head_pos": np.array, "prev_snake_head_pos": np.array, "food_pos": np.array, "food_obtained": bool}
obs = self._generate_observation()

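# Binary occupancy grid of the snake body; is_circle() counts the empty cells that the body fully encloses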
compressed_obs = np.zeros((self.game.board_size, self.game.board_size), dtype=np.uint8)
compressed_obs[tuple(np.transpose(self.game.snake))] = 1

reward = 0.0
self.reward_step_counter += 1

@@ -72,12 +76,17 @@ def step(self, action):
self.reward_step_counter = 0 # Reset reward step counter

else:
# reward_circle = - is_circle(compressed_obs) * info["snake_size"]

# Penalize the agent when its body encloses empty cells of the board
reward_circle = - (is_circle(compressed_obs) * 0.1) / info["snake_size"] # scaled by 0.1 so the penalty does not outweigh the other per-step rewards and penalties
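# e.g. 2 enclosed cells with snake_size = 10 gives reward_circle = -(2 * 0.1) / 10 = -0.02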

# Give a tiny reward/penalty to the agent based on whether it is heading towards the food or not.
# Not competing with game over penalty or the food eaten reward.
if np.linalg.norm(info["snake_head_pos"] - info["food_pos"]) < np.linalg.norm(info["prev_snake_head_pos"] - info["food_pos"]):
reward = 1 / info["snake_size"]
reward = (1 / info["snake_size"]) + reward_circle
else:
reward = - 1 / info["snake_size"]
reward = (- 1 / info["snake_size"]) + reward_circle
reward = reward * 0.1

# max_score: 72 + 14.1 = 86.1
11 changes: 9 additions & 2 deletions main/snake_game_custom_wrapper_mlp.py
@@ -5,6 +5,7 @@
import numpy as np

from snake_game import SnakeGame
from circle_detection import is_circle

class SnakeEnv(gym.Env):
def __init__(self, seed=0, board_size=12, silent_mode=True, limit_step=True):
@@ -46,6 +47,9 @@ def step(self, action):
self.done, info = self.game.step(action) # info = {"snake_size": int, "snake_head_pos": np.array, "prev_snake_head_pos": np.array, "food_pos": np.array, "food_obtained": bool}
obs = self._generate_observation()

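# Binary occupancy grid of the snake body, used by is_circle() to detect enclosed empty cells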
compressed_obs = np.zeros((self.game.board_size, self.game.board_size), dtype=np.uint8)
compressed_obs[tuple(np.transpose(self.game.snake))] = 1

reward = 0.0
self.reward_step_counter += 1

@@ -68,10 +72,13 @@ def step(self, action):
self.reward_step_counter = 0 # Reset reward step counter

else:
# Penalize the agent when its body encloses empty cells of the board
reward_circle = - (is_circle(compressed_obs) * 0.1) / info["snake_size"] # scaled by 0.1 so the penalty does not outweigh the other per-step rewards and penalties

if np.linalg.norm(info["snake_head_pos"] - info["food_pos"]) < np.linalg.norm(info["prev_snake_head_pos"] - info["food_pos"]):
reward = 1 / info["snake_size"] # No upper limit might enable the agent to master shorter scenario faster and more firmly.
reward = (1 / info["snake_size"]) + reward_circle # No upper limit might enable the agent to master shorter scenario faster and more firmly.
else:
reward = - 1 / info["snake_size"]
reward = (- 1 / info["snake_size"]) + reward_circle
# print(reward*0.1)
# time.sleep(1)

6 changes: 5 additions & 1 deletion main/test_mlp.py
@@ -1,11 +1,15 @@
import time
import random

import torch
from sb3_contrib import MaskablePPO

from snake_game_custom_wrapper_mlp import SnakeEnv

MODEL_PATH = r"trained_models_mlp/ppo_snake_final"
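# Load the MPS-trained weights on Apple Silicon; otherwise fall back to the default model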
if torch.backends.mps.is_available():
MODEL_PATH = r"trained_models_mlp_mps/ppo_snake_final"
else:
MODEL_PATH = r"trained_models_mlp/ppo_snake_final"

NUM_EPISODE = 10

67 changes: 46 additions & 21 deletions main/train_mlp.py
@@ -2,6 +2,7 @@
import sys
import random

import torch
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback
@@ -10,7 +11,10 @@

from snake_game_custom_wrapper_mlp import SnakeEnv

NUM_ENV = 32
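# Use twice as many parallel environments when MPS (Apple Silicon) is available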
if torch.backends.mps.is_available():
NUM_ENV = 32 * 2
else:
NUM_ENV = 32
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)

@@ -46,26 +50,47 @@ def main():
# Create the Snake environment.
env = SubprocVecEnv([make_env(seed=s) for s in seed_set])

lr_schedule = linear_schedule(2.5e-4, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)

# # Instantiate a PPO agent
model = MaskablePPO(
"MlpPolicy",
env,
device="cuda",
verbose=1,
n_steps=2048,
batch_size=512,
n_epochs=4,
gamma=0.94,
learning_rate=lr_schedule,
clip_range=clip_range_schedule,
tensorboard_log=LOG_DIR
)

# Set the save directory
save_dir = "trained_models_mlp"
if torch.backends.mps.is_available():
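# The MPS run uses a higher initial learning rate together with an 8x larger batch size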
lr_schedule = linear_schedule(5e-4, 2.5e-6)
clip_range_schedule = linear_schedule(0.150, 0.025)
# Instantiate a PPO agent using MPS (Metal Performance Shaders).
model = MaskablePPO(
"MlpPolicy",
env,
device="mps",
verbose=1,
n_steps=2048,
batch_size=512*8,
n_epochs=4,
gamma=0.94,
learning_rate=lr_schedule,
clip_range=clip_range_schedule,
tensorboard_log=LOG_DIR
)
else:
lr_schedule = linear_schedule(2.5e-4, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)

# # Instantiate a PPO agent
model = MaskablePPO(
"MlpPolicy",
env,
device="cuda",
verbose=1,
n_steps=2048,
batch_size=512,
n_epochs=4,
gamma=0.94,
learning_rate=lr_schedule,
clip_range=clip_range_schedule,
tensorboard_log=LOG_DIR
)

# Set the save directory
if torch.backends.mps.is_available():
save_dir = "trained_models_mlp_mps"
else:
save_dir = "trained_models_mlp"
os.makedirs(save_dir, exist_ok=True)

checkpoint_interval = 15625 # checkpoint_interval * num_envs = total_steps_per_checkpoint
Binary file modified main/trained_models_cnn/ppo_snake_final.zip
Binary file not shown.
Binary file modified main/trained_models_cnn_mps/ppo_snake_final.zip
Binary file not shown.
Binary file modified main/trained_models_mlp/ppo_snake_final.zip
Binary file not shown.
Binary file added main/trained_models_mlp_mps/ppo_snake_final.zip
Binary file not shown.