This tool is designed to train and evaluate reinforcement learning agents for epidemic control simulations based on stochastic discrete epidemic models. The agents are implemented using model-free off-policy methods. Specifically, we employ tabular Q-Learning and Deep Q-Networks (DQN) to learn policies that control the spread of an epidemic in a single classroom.
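For readers new to the method, the tabular Q-learning update can be sketched in a few lines. This is the standard textbook form, not the project's implementation. It uses an epsilon-greedy behavior policy while bootstrapping from the greedy (max) action, and that mismatch is what makes it off-policy. The class name `TabularQLearner`, its parameters, and the use of arbitrary hashable states are assumptions for illustration. The step size is called `learning_rate` here because this README does not say whether it corresponds to the `--alpha` flag used in the commands.

```python
import random
from collections import defaultdict
from typing import DefaultDict, Hashable, List, Optional


class TabularQLearner:
    """Hypothetical minimal tabular Q-learning agent (illustrative only)."""

    def __init__(self, n_actions: int, learning_rate: float = 0.1,
                 discount: float = 0.99, epsilon: float = 0.1,
                 seed: Optional[int] = None) -> None:
        self.n_actions = n_actions
        self.lr = learning_rate
        self.gamma = discount
        self.epsilon = epsilon
        self.rng = random.Random(seed)
        # Unseen states start with all action values at zero.
        self.q: DefaultDict[Hashable, List[float]] = defaultdict(
            lambda: [0.0] * n_actions)

    def act(self, state: Hashable) -> int:
        """Epsilon-greedy behavior policy."""
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(self.n_actions)
        values = self.q[state]
        return max(range(self.n_actions), key=values.__getitem__)

    def update(self, state: Hashable, action: int, reward: float,
               next_state: Hashable, done: bool) -> float:
        """Apply one Q-learning update and return the TD error."""
        # Greedy (max) bootstrap target; no bootstrap on terminal steps.
        target = reward if done else reward + self.gamma * max(self.q[next_state])
        td_error = target - self.q[state][action]
        self.q[state][action] += self.lr * td_error
        return td_error
```

A DQN replaces the table `q` with a neural network trained toward the same kind of target, which lets it handle state spaces too large to enumerate.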
Ensure you have the following installed:
- Python 3.8 or higher
- pip (Python package installer)
- Conda (for environment management)
First, clone the project repository from GitHub:
git clone https://github.com/your-github-username/your-repo-name.git
cd your-repo-name
Create and activate a new conda environment:
conda create -n campus_gym python=3.8
conda activate campus_gym
Install the required packages using pip:
pip install -r requirements.txt
If you haven't already, sign up for a wandb account at https://wandb.ai. Then, log in via the command line:
wandb login
Ensure all configuration files are in place and update the wandb settings with your username and project name. The configuration files are located in the config directory and include the following:
- config/config_shared.yaml: Shared configuration settings
- config/config_{agent_type}.yaml: Agent-specific configurations
- config/optuna_config.yaml: Optuna hyperparameter optimization settings (if using Optuna)
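Before launching a run, you can check that the configuration files listed above exist. The following is a minimal sketch; the function name `missing_config_files` is hypothetical, and it only checks that the files are present, not that their contents are valid.

```python
from pathlib import Path
from typing import List, Union


def missing_config_files(config_dir: Union[str, Path], agent_type: str,
                         use_optuna: bool = False) -> List[str]:
    """Return the names of expected config files that are not present."""
    config_dir = Path(config_dir)
    required = ["config_shared.yaml", f"config_{agent_type}.yaml"]
    if use_optuna:
        # optuna_config.yaml is only needed for the optuna mode.
        required.append("optuna_config.yaml")
    return [name for name in required if not (config_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_config_files("config", "q_learning")
    print("missing:", missing or "none")
```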
Prepare Community Risk CSV (if applicable)
If you're planning to use community risk data from a CSV file, make sure the file is accessible (the bundled aggregated_weekly_risk_levels.csv is in the project's root directory) and pass its path with the --csv_path argument when running the script.
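To sanity-check a risk CSV before passing it to --csv_path, you can read one column of it with the standard library. This is a sketch under assumptions: this README does not document the file's columns, so the default column name `risk_level` is a placeholder (check the header of your file), and the function names are hypothetical.

```python
import csv
from typing import Iterable, List


def read_risk_levels(lines: Iterable[str], column: str = "risk_level") -> List[float]:
    """Parse one numeric column from CSV text lines (header row first)."""
    reader = csv.DictReader(lines)
    # Accessing fieldnames reads the header row.
    if reader.fieldnames is None or column not in reader.fieldnames:
        raise KeyError(f"column {column!r} not found; header is {reader.fieldnames}")
    return [float(row[column]) for row in reader]


def load_risk_levels(path: str, column: str = "risk_level") -> List[float]:
    """Open a CSV file and return the values of the given column."""
    with open(path, newline="") as f:
        return read_risk_levels(f, column)
```

For example, `load_risk_levels("aggregated_weekly_risk_levels.csv", column="<your column>")` fails fast with a `KeyError` that shows the actual header if the column name is wrong.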
You can run the project in different modes:
- Training:
python main.py train --agent_type q_learning --alpha 0.8 --algorithm q_learning
- Evaluation (uses aggregated_weekly_risk_levels.csv, located in the project's root directory):
python main.py eval --agent_type q_learning --alpha 0.8 --run_name your_run_name --csv_path aggregated_weekly_risk_levels.csv --algorithm q_learning
- Combined Training and Evaluation:
python main.py train_and_eval --agent_type q_learning --alpha 0.8 --csv_path aggregated_weekly_risk_levels.csv --algorithm q_learning
- Hyperparameter Sweep:
python main.py sweep --agent_type q_learning
- Multiple Runs:
python main.py multi --agent_type q_learning --alpha_t 0.05 --beta_t 0.9 --num_runs 10
- Optuna Optimization:
python main.py optuna --agent_type q_learning
Replace q_learning with dqn (in --agent_type and, where present, --algorithm) to use the DQN algorithm instead.
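If you script several runs (for example, the same mode for both agent types), a small helper can assemble the commands above. The helper name `build_command` is hypothetical. The modes and flags come from this section, and extra options are forwarded unchanged as `--name value`, so main.py remains the authority on which flags each mode accepts.

```python
import subprocess
import sys
from typing import List

MODES = {"train", "eval", "train_and_eval", "sweep", "multi", "optuna"}
AGENT_TYPES = {"q_learning", "dqn"}


def build_command(mode: str, agent_type: str, **options: object) -> List[str]:
    """Build an argv list for main.py; options become --name value pairs."""
    if mode not in MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(MODES)}")
    if agent_type not in AGENT_TYPES:
        raise ValueError(f"unknown agent type {agent_type!r}")
    cmd = [sys.executable, "main.py", mode, "--agent_type", agent_type]
    for name, value in options.items():
        cmd += [f"--{name}", str(value)]
    return cmd


if __name__ == "__main__":
    for agent in ("q_learning", "dqn"):
        # check=True stops the loop if a run exits with an error.
        subprocess.run(build_command("train", agent, alpha=0.8, algorithm=agent), check=True)
```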
If you encounter any issues during installation or running the project, please check the following:
- Ensure all prerequisites are correctly installed.
- Verify that you're in the correct conda environment (conda activate campus_gym).
- Check that all configuration files are properly set up.
- Make sure you have the necessary permissions to read/write in the project directory.
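The first two checks can be automated with a short script. This is a sketch, and `environment_problems` is a hypothetical name. It relies on `CONDA_DEFAULT_ENV`, an environment variable that `conda activate` sets to the active environment's name, and on the Python 3.8 minimum stated in the prerequisites.

```python
import os
import sys
from typing import List, Mapping, Optional, Sequence


def environment_problems(version: Optional[Sequence[int]] = None,
                         env: Optional[Mapping[str, str]] = None,
                         expected_env: str = "campus_gym") -> List[str]:
    """Return human-readable problems; an empty list means all checks passed."""
    version = tuple(sys.version_info[:2]) if version is None else tuple(version)
    env = os.environ if env is None else env
    problems = []
    if version < (3, 8):
        problems.append(f"Python {version[0]}.{version[1]} found; 3.8 or higher is required")
    active = env.get("CONDA_DEFAULT_ENV")
    if active != expected_env:
        problems.append(f"active conda environment is {active!r}, expected {expected_env!r}")
    return problems


if __name__ == "__main__":
    for problem in environment_problems():
        print("WARNING:", problem)
```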
If problems persist, please open an issue on the GitHub repository with details about the error and your environment.
This section provides detailed information on how to view, analyze, and interpret the results generated by the project. The results include training metrics, evaluation metrics, visualizations, and safety analysis.
The project uses Weights & Biases (W&B) to log and visualize training and evaluation metrics. You can view the results by logging into your W&B account and navigating to the project dashboard.
- Cumulative Reward: The total reward accumulated over the course of an episode.
For each training and evaluation run, metrics are also saved locally as CSV files in the results directory. The following CSV files are generated:
- training_metrics_<run_name>.csv: Contains training metrics for each episode.
- evaluation_metrics_<run_name>.csv: Contains evaluation metrics for each step within an episode.
- mean_allowed_infected.csv: Summarizes the mean allowed and infected values across all episodes.
These files can be found in the results subdirectory specific to each run, typically located at <results_directory>/<agent_type>/<run_name>/<timestamp>/.
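Because every run gets its own timestamped subdirectory, finding the latest output for a run is a common chore. The sketch below assumes that timestamp folder names sort chronologically as plain strings, which holds for formats such as `2024-02-01_09-30-00` but is not confirmed by this README. The function names and the example timestamps are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import List, Optional, Union


def latest_run_dir(results_dir: Union[str, Path], agent_type: str,
                   run_name: str) -> Optional[Path]:
    """Return <results_dir>/<agent_type>/<run_name>/<latest timestamp>, or None."""
    run_root = Path(results_dir) / agent_type / run_name
    if not run_root.is_dir():
        return None
    # Assumes timestamp folder names sort chronologically as strings.
    stamps = sorted((p for p in run_root.iterdir() if p.is_dir()), key=lambda p: p.name)
    return stamps[-1] if stamps else None


def list_result_csvs(run_dir: Path) -> List[str]:
    """List the CSV files (metrics, safety conditions) saved for one run."""
    return sorted(p.name for p in run_dir.glob("*.csv"))


if __name__ == "__main__":
    # Self-contained demo on a throwaway directory tree.
    with tempfile.TemporaryDirectory() as tmp:
        for stamp in ("2024-01-05_10-00-00", "2024-02-01_09-30-00"):
            (Path(tmp) / "q_learning" / "my_run" / stamp).mkdir(parents=True)
        latest = latest_run_dir(tmp, "q_learning", "my_run")
        (latest / "training_metrics_my_run.csv").write_text("episode\n")
        print(latest.name, list_result_csvs(latest))
```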
All visualizations are saved in the results subdirectory for each run. The specific path is <results_directory>/<agent_type>/<run_name>/<timestamp>/.
The Tolerance Interval Curve shows, at each point along the curve, an interval that contains at least a proportion beta of the returns across runs, with confidence level 1 - alpha. This curve helps visualize how consistently the model performs across different runs.
- tolerance_interval_mean.png: The mean performance with the tolerance interval.
- tolerance_interval_median.png: The median performance with the tolerance interval.
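To make the roles of alpha and beta concrete, here is one standard construction: the distribution-free (Wilks) tolerance interval that takes the minimum and maximum of n runs. This README does not say which construction the project uses (normal-theory intervals are also common), so treat this as an illustration of the concept, not the project's method. It assumes the returns at a given point are i.i.d. draws from a continuous distribution, and the function names are hypothetical.

```python
from typing import Sequence, Tuple


def range_coverage_confidence(n: int, beta: float) -> float:
    """Confidence that [min, max] of n i.i.d. samples covers >= beta of the distribution."""
    if n < 2:
        raise ValueError("need at least two samples")
    return 1.0 - n * beta ** (n - 1) + (n - 1) * beta ** n


def min_runs_for(alpha: float, beta: float) -> int:
    """Smallest n whose min/max interval reaches confidence 1 - alpha."""
    n = 2
    while range_coverage_confidence(n, beta) < 1.0 - alpha:
        n += 1
    return n


def tolerance_interval(values: Sequence[float], alpha: float,
                       beta: float) -> Tuple[float, float]:
    """Distribution-free tolerance interval for one point on the curve."""
    n = len(values)
    if n < 2 or range_coverage_confidence(n, beta) < 1.0 - alpha:
        raise ValueError(f"{n} runs are too few; need at least {min_runs_for(alpha, beta)}")
    return min(values), max(values)
```

Applying `tolerance_interval` to the returns of all runs at each step traces out the band. Under this particular construction, alpha = 0.05 and beta = 0.9 require 46 runs, and 10 runs give only about 26% confidence. This is one reason parametric (normal-theory) intervals are often preferred when runs are expensive.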
The Confidence Interval Curve shows the mean performance of the model with a confidence interval, typically at 95%. This visualization helps assess the reliability of the model's performance.
- confidence_interval.png: The confidence interval for mean performance.
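For comparison with the tolerance interval, a confidence interval bounds the *mean* return rather than individual runs, so it narrows as more runs are added. The sketch below uses the normal approximation from Python's `statistics` module (`fmean` and `NormalDist` require Python 3.8+). It is an assumption about how such an interval can be computed, not the project's code. With only a handful of runs, a Student-t critical value (for example from SciPy) would give a somewhat wider and more appropriate interval.

```python
import math
import statistics
from typing import Sequence, Tuple


def mean_confidence_interval(values: Sequence[float],
                             confidence: float = 0.95) -> Tuple[float, float, float]:
    """Return (mean, lower, upper) using a normal approximation."""
    if len(values) < 2:
        raise ValueError("need at least two values")
    mean = statistics.fmean(values)
    # Standard error of the mean from the sample standard deviation.
    sem = statistics.stdev(values) / math.sqrt(len(values))
    z = statistics.NormalDist().inv_cdf(0.5 + confidence / 2)  # about 1.96 for 95%
    return mean, mean - z * sem, mean + z * sem
```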
The Safety Set Identification plot shows the states where the model's policy maintains safety constraints, such as keeping the number of infections below a threshold.
- safety_set_plot_episode_<run_name>.png: A plot showing the safety set for a specific episode.
For Q-learning, an evaluation plot is generated showing the allowed students, infected individuals, and community risk over time.
- evaluation_plot_<run_name>.png: Plot of evaluation results.
The safety set conditions are logged in the safety_conditions_<run_name>.csv file. This file contains the following columns:
- Episode: The episode number.
- Infections > Threshold (%): The percentage of time the number of infections exceeded the threshold.
- Safety Condition Met (Infection): Indicates whether the infection safety condition was met.
- Allowed Students ≥ Threshold (%): The percentage of time the allowed students met or exceeded the threshold.
- Safety Condition Met (Attendance): Indicates whether the attendance safety condition was met.
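The sketch below shows how one row of this file could be derived from per-step infection and allowed-student counts. The column names come from the list above. The pass rules are assumptions, because this README does not state when a condition counts as "met". Here the infection condition passes if the exceedance percentage is at most `max_infection_violation_pct`, and the attendance condition passes if the attendance percentage is at least `min_attendance_pct`. Both of these parameters and the function name are hypothetical.

```python
import csv
import sys
from typing import Dict, Sequence


def safety_summary(episode: int, infections: Sequence[float], allowed: Sequence[float],
                   infection_threshold: float, attendance_threshold: float,
                   max_infection_violation_pct: float = 0.0,
                   min_attendance_pct: float = 100.0) -> Dict[str, object]:
    """Build one safety-conditions row from per-step series of equal length."""
    steps = len(infections)
    if steps == 0 or steps != len(allowed):
        raise ValueError("series must be non-empty and of equal length")
    over = sum(1 for x in infections if x > infection_threshold)
    attend = sum(1 for a in allowed if a >= attendance_threshold)
    over_pct = 100.0 * over / steps
    attend_pct = 100.0 * attend / steps
    return {
        "Episode": episode,
        "Infections > Threshold (%)": over_pct,
        "Safety Condition Met (Infection)": over_pct <= max_infection_violation_pct,
        "Allowed Students ≥ Threshold (%)": attend_pct,
        "Safety Condition Met (Attendance)": attend_pct >= min_attendance_pct,
    }


if __name__ == "__main__":
    row = safety_summary(0, [1, 5, 2, 8], [100, 50, 100, 100], 4, 75, 10.0, 70.0)
    writer = csv.DictWriter(sys.stdout, fieldnames=list(row))
    writer.writeheader()
    writer.writerow(row)
```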
This file provides a quick overview of the model's performance and safety compliance across the entire evaluation period.
Use the above results and visualizations to assess and refine your model. The combination of metrics, visualizations, and safety analysis will help you understand the model's strengths and weaknesses and guide further development. Remember to check the specific results subdirectory for each run to find all the generated files and visualizations.