This tool is designed to train and evaluate reinforcement learning agents for epidemic control simulations based on stochastic discrete epidemic models. The agents are implemented using model-free off-policy methods. Specifically, we employ tabular Q-Learning and Deep Q-Networks (DQN) to learn policies for controlling the spread of an epidemic for a single classroom operation.
- Python 3.8 or higher
- pip
- virtualenv
-
Clone the repository:
git clone https://github.com/yourusername/SafeCampus-Multidiscrete.git cd SafeCampus-Multidiscrete
-
Create a virtual environment:
python3 -m venv venv source venv/bin/activate # On Windows use `venv\Scripts\activate`
-
Install the required packages:
pip install -r requirements.txt
-
Install the custom Gym environment:
cd campus_gym pip install -e . cd ..
Configuration files for the shared settings and specific agents are located in the config
directory.
The shared configuration is specified in config/config_shared.yaml
. It includes settings that are common across different agents, such as environment details and logging configurations.
Agent-specific configurations are specified in separate YAML files, such as config/config_q_learning.yaml
for Q-Learning and config/config_dqn.yaml
for DQN. These files contain settings specific to each agent, including hyperparameters.
To train an agent, use the following command:
python main.py train --alpha 0.5 --agent_type q_learning
To evaluate a trained agent, use the following command:
python main.py eval --alpha 0.5 --agent_type q_learning --run_name your_run_name
or for DQN:
python main.py eval --alpha 0.5 --agent_type dqn --run_name your_run_name
To run multiple training runs and calculate tolerance intervals, use:
python main.py multi --alpha_t 0.05 --beta_t 0.9 --num_runs 5 --agent_type q_learning
or for DQN:
python main.py multi --alpha_t 0.05 --beta_t 0.9 --num_runs 5 --agent_type dqn
To perform hyperparameter optimization using Optuna, use:
python main.py optuna --agent_type q_learning
or for DQN:
python main.py optuna --agent_type dqn