diff --git a/artifacts/rollout_policy_10x10_4v2.pt b/artifacts/rollout_policy_10x10_4v2.pt
index 98729eb..9b9b796 100644
Binary files a/artifacts/rollout_policy_10x10_4v2.pt and b/artifacts/rollout_policy_10x10_4v2.pt differ
diff --git a/ma_gym/envs/predator_prey/predator_prey.py b/ma_gym/envs/predator_prey/predator_prey.py
index 150d947..6ba59d6 100644
--- a/ma_gym/envs/predator_prey/predator_prey.py
+++ b/ma_gym/envs/predator_prey/predator_prey.py
@@ -378,6 +378,13 @@ def substep(self, agent_id, action):
 
         return self.get_agent_obs(), rewards, self._agent_dones, {'prey_alive': self._prey_alive}
 
+    def apply_move(self, agent_id, action):
+        # one agent moves
+        if not (self._agent_dones[agent_id]):
+            self.__update_agent_pos(agent_id, action)
+
+        return self.get_agent_obs()
+
     def __get_neighbour_coordinates(self, pos):
         neighbours = []
         if self.is_valid([pos[0] + 1, pos[1]]):
@@ -428,39 +435,6 @@ def close(self):
             self.viewer.close()
             self.viewer = None
 
-    def get_distances(self):
-        distances = []
-
-        n_actions = len(ACTION_MEANING)
-
-        for agent_curr_pos in self.agent_pos.values():
-            # initialize to inf (max distance)
-            a_distances = np.full(shape=(n_actions, self.n_preys), fill_value=np.inf, dtype=np.float32)
-
-            for action in ACTION_MEANING:
-                # apply selected action to the current position
-                modified_agent_pos = self._apply_action(agent_curr_pos, action)
-                if modified_agent_pos is not None:
-                    for j, p_pos in self.prey_pos.items():
-                        if self._prey_alive[j]:
-                            # calc MD
-                            md = np.abs(p_pos[0] - modified_agent_pos[0]) + np.abs(p_pos[1] - modified_agent_pos[1])
-                            a_distances[action, j] = md
-
-            # post-processing: replace dist from invalid moves (inf) with distance of MAX+1
-            for col in a_distances.T:
-                if np.inf in col and not np.all(col == np.inf):
-                    max_dist = np.max(col[col != np.inf])
-                    col[col == np.inf] = max_dist + 1
-
-            # check that to action yields (inf, inf)
-            has_inf = np.all(a_distances == np.inf, axis=1)
-            assert True not in has_inf
-
-            distances.append(a_distances)
-
-        return distances
-
     def _apply_action(self, curr_pos, move):
         # curr_pos = copy.copy(self.agent_pos[agent_i])
         if move == 0:  # down
diff --git a/scripts/learn_rollout_policy_online.py b/scripts/learn_rollout_policy_online.py
index e69de29..4827fe7 100644
--- a/scripts/learn_rollout_policy_online.py
+++ b/scripts/learn_rollout_policy_online.py
@@ -0,0 +1,269 @@
+import time
+from typing import List, Iterable
+
+from tqdm import tqdm
+import numpy as np
+import gym
+import ma_gym  # register new envs on import
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from src.constants import SpiderAndFlyEnv, RolloutModelPath_10x10_4v2
+from src.qnetwork import QNetwork
+from src.agent_rule_based import RuleBasedAgent
+from src.agent_approx_rollout import RolloutAgent
+
+N_EPISODES = 3
+
+N_SIMS_PER_STEP = 10
+
+BATCH_SIZE = 512
+EPOCHS = 1
+
+
+def simulate(
+        initial_obs: np.array,
+        initial_step: np.array,
+        m_agents: int,
+        qnet,
+        fake_env,
+        action_space,
+) -> float:
+    avg_total_reward = .0
+
+    # create env
+    env = gym.make(SpiderAndFlyEnv)
+
+    # run N simulations
+    for _ in range(N_SIMS_PER_STEP):
+        obs_n = env.reset_from(initial_obs)
+
+        # 1 step
+        obs_n, reward_n, done_n, info = env.step(initial_step)
+        avg_total_reward += np.sum(reward_n)
+
+        # run an episode until all prey is caught
+        while not all(done_n):
+
+            # all agents act based on the observation
+            act_n = []
+
+            prev_actions = {}
+
+            for agent_id, obs in enumerate(obs_n):
+                obs_after_coordination = update_obs(fake_env, obs, prev_actions)
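+                # the fake_env replays the moves already committed by preceding agents,
+                # so this agent picks its action conditioned on their choices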
+                action_taken = epsilon_greedy_step(
+                    obs_after_coordination, m_agents, agent_id, qnet, action_space)
+
+                prev_actions[agent_id] = action_taken
+                act_n.append(action_taken)
+
+            # update step
+            obs_n, reward_n, done_n, info = env.step(act_n)
+
+            avg_total_reward += np.sum(reward_n)
+
+    env.close()
+
+    avg_total_reward /= m_agents
+    avg_total_reward /= N_SIMS_PER_STEP
+
+    return avg_total_reward
+
+
+def convert_to_x(obs, m_agents, agent_id):
+    # state
+    obs_first = np.array(obs, dtype=np.float32).flatten()
+
+    # agent ohe
+    agent_ohe = np.zeros(shape=(m_agents,), dtype=np.float32)
+    agent_ohe[agent_id] = 1.
+
+    x = np.concatenate((obs_first, agent_ohe))
+
+    return x
+
+
+def epsilon_greedy_step(obs, m_agents, agent_id, qnet, action_space, epsilon=0.05):
+    p = np.random.random()
+    if p < epsilon:
+        # random action -> exploration
+        return action_space.sample()
+    else:
+        qnet.eval()
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # argmax -> exploitation
+        x = convert_to_x(obs, m_agents, agent_id)
+        x = np.reshape(x, newshape=(1, -1))
+        # v = torch.from_numpy(x)
+        v = torch.tensor(x, device=device)
+        qs = qnet(v)
+        return np.argmax(qs.data.cpu().numpy())
+
+
+def epsilon_greedy_step_from_array(qvalues, action_space, epsilon=0.05):
+    p = np.random.random()
+    if p < epsilon:
+        # random action -> exploration
+        return action_space.sample()
+    else:
+        # argmax -> exploitation
+        return np.argmax(qvalues)
+
+
+def update_obs(env, obs, prev_actions):
+    if prev_actions:
+        obs_new = env.reset_from(obs)
+        for agent_id, action in prev_actions.items():
+            obs_new = env.apply_move(agent_id, action)
+
+        return obs_new[0]
+    else:
+        return obs
+
+
+def train_qnet(qnet, samples):
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    qnet.train()
+
+    criterion = nn.MSELoss()
+    optimizer = optim.Adam(qnet.parameters(), lr=0.01)
+    data_loader = torch.utils.data.DataLoader(samples,
+                                              batch_size=BATCH_SIZE,
+                                              shuffle=True)
+
+    for epoch in range(EPOCHS):  # loop over the dataset multiple times
+        running_loss = .0
+        n_batches = 0
+
+        for data in data_loader:
+            inputs, labels = data[0].to(device), data[1].to(device)
+            optimizer.zero_grad()
+            outputs = qnet(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            # logging
+            running_loss += loss.item()
+            n_batches += 1
+
+        print(f'[{epoch}] {running_loss / n_batches:.3f}.')
+
+    return qnet
+
+
+def learn_policy():
+    np.random.seed(42)
+
+    # create Spiders-and-Flies game
+    env = gym.make(SpiderAndFlyEnv)
+    env.seed(42)
+
+    # used only for modeling
+    fake_env = gym.make(SpiderAndFlyEnv)
+    fake_env.seed(1)
+
+    # init env variables
+    m_agents = env.n_agents
+    p_preys = env.n_preys
+    grid_shape = env._grid_shape
+    action_space = env.action_space[0]
+
+    # base net
+    # base_qnet = QNetwork(m_agents, p_preys, action_space.n)
+    # base_qnet.load_state_dict(torch.load(RolloutModelPath_10x10_4v2))
+    # base_qnet.eval()
+
+    # rollout net
+    rollout_qnet = QNetwork(m_agents, p_preys, action_space.n)
+    rollout_qnet.load_state_dict(torch.load(RolloutModelPath_10x10_4v2))
+    # rollout_qnet.train()
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    rollout_qnet.to(device)
+
+    for _ in tqdm(range(N_EPISODES)):
+        # init env
+        obs_n = env.reset()
+
+        # init stopping condition
+        done_n = [False] * env.n_agents
+
+        # run an episode until all prey is caught
+        while not all(done_n):
+            prev_actions = {}
+            act_n = []
+
+            samples = []
+
+            for agent_id, obs in enumerate(obs_n):
+                n_actions = action_space.n
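+                # q_values[a] will hold a Monte Carlo estimate of the return from playing
+                # action a now and following the current rollout policy afterwards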
+                q_values = np.full((n_actions,), fill_value=-np.inf, dtype=np.float32)
+
+                new_actions = {}
+
+                for action_id in range(n_actions):
+                    # 1st step - optimal actions from previous agents,
+                    # simulated step from current agent,
+                    # greedy (baseline) from undecided agents
+                    sub_actions = np.empty((m_agents,), dtype=np.int8)
+
+                    for i in range(m_agents):
+                        if i in prev_actions:
+                            sub_actions[i] = prev_actions[i]
+                        elif agent_id == i:
+                            sub_actions[i] = action_id
+
+                            new_actions[i] = action_id
+                        else:
+                            # update obs with info about prev steps
+                            obs_after_coordination = update_obs(fake_env, obs, {**prev_actions, **new_actions})
+                            best_action = epsilon_greedy_step(
+                                obs_after_coordination, m_agents, i, rollout_qnet, action_space)
+
+                            sub_actions[i] = best_action
+                            new_actions[i] = best_action
+
+                    # run N simulations
+                    avg_total_reward = simulate(obs, sub_actions, m_agents, rollout_qnet, fake_env, action_space)
+
+                    q_values[action_id] = avg_total_reward
+
+                # adds sample to the dataset
+                agent_ohe = np.zeros(shape=(m_agents,), dtype=np.float32)
+                agent_ohe[agent_id] = 1.
+                obs_after_coordination = np.array(update_obs(fake_env, obs, prev_actions), dtype=np.float32)
+                x = np.concatenate((obs_after_coordination, agent_ohe))
+                samples.append((x, q_values))
+
+                # TODO sanity check
+                # print(f'Qnet: {q_values}')
+                # print(f'MC: {}')
+
+                # current policy
+                action_taken = epsilon_greedy_step_from_array(q_values, action_space)
+
+                prev_actions[agent_id] = action_taken
+                act_n.append(action_taken)
+
+            # update step
+            obs_n, reward_n, done_n, info = env.step(act_n)
+
+            # update rollout policy with samples
+            rollout_qnet = train_qnet(rollout_qnet, samples)
+
+    env.close()
+
+    # save updated qnet
+    torch.save(rollout_qnet.state_dict(), RolloutModelPath_10x10_4v2)
+
+
+# ------------- Runner -------------
+
+if __name__ == '__main__':
+    learn_policy()