ANRGUSC/SafeCampus-Multidiscrete

SafeCampus-Multidiscrete

This tool trains and evaluates reinforcement learning agents for epidemic control in simulations based on stochastic discrete epidemic models. The agents use model-free, off-policy methods: tabular Q-Learning and Deep Q-Networks (DQN) learn policies for controlling the spread of an epidemic in a single classroom.
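For readers new to the method, the following is a minimal sketch of the standard tabular Q-learning update. It is not taken from this repository: the function name, the dictionary-based table, and the lr and gamma defaults are illustrative assumptions.

```python
from collections import defaultdict


def q_update(q, state, action, reward, next_state, n_actions, lr=0.1, gamma=0.99):
    """Apply one tabular Q-learning update in place and return the new value.

    All names and defaults here are hypothetical, chosen for illustration.
    """
    # Off-policy target: bootstrap from the greedy (max) action in next_state,
    # regardless of which action the behaviour policy actually takes next.
    best_next = max(q[(next_state, a)] for a in range(n_actions))
    td_target = reward + gamma * best_next
    q[(state, action)] += lr * (td_target - q[(state, action)])
    return q[(state, action)]


q_table = defaultdict(float)  # unseen (state, action) pairs start at 0.0
new_value = q_update(q_table, state=0, action=1, reward=1.0, next_state=2, n_actions=3)
```

Because the target uses the maximum over next actions rather than the action actually taken, the update is off-policy, which is the property the description above refers to.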

Installation

Prerequisites

Ensure you have the following installed:

  • Python 3.8 or higher
  • pip (Python package installer)
  • Conda (for environment management)

Installation Steps

1. Clone the Repository

First, clone the project repository from GitHub:

git clone https://github.com/ANRGUSC/SafeCampus-Multidiscrete.git
cd SafeCampus-Multidiscrete

2. Set Up Conda Environment

Create and activate a new conda environment:

conda create -n campus_gym python=3.8
conda activate campus_gym

3. Install Dependencies

Install the required packages using pip:

pip install -r requirements.txt

4. Configure Weights & Biases (wandb)

If you haven't already, sign up for a wandb account at https://wandb.ai. Then, log in via the command line:

wandb login

5. Set Up Configuration Files

Ensure all configuration files are in place and update the wandb settings with your username and project name. The configuration files are located in the config directory and include the following:

  • config/config_shared.yaml: Shared configuration settings
  • config/config_{agent_type}.yaml: Agent-specific configurations
  • config/optuna_config.yaml: Optuna hyperparameter optimization settings (if using)
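Before a run, it can help to confirm that the files listed above exist. The sketch below only builds the expected paths from the naming pattern given in this README and checks for their presence; the helper names are hypothetical and are not part of the project.

```python
from pathlib import Path


def expected_config_files(agent_type, config_dir="config", use_optuna=False):
    """Return the config paths this README lists for a given agent type."""
    base = Path(config_dir)
    files = [base / "config_shared.yaml", base / f"config_{agent_type}.yaml"]
    if use_optuna:
        # Only needed when running Optuna hyperparameter optimization.
        files.append(base / "optuna_config.yaml")
    return files


def missing_config_files(agent_type, config_dir="config", use_optuna=False):
    """Return the expected config paths that are not present on disk."""
    expected = expected_config_files(agent_type, config_dir, use_optuna)
    return [path for path in expected if not path.is_file()]
```

Calling missing_config_files("q_learning") from the project root returns an empty list when both required files are in place.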

6. Prepare Community Risk CSV (if applicable)

If you're planning to use community risk data from a CSV file, ensure it's placed in the appropriate directory and update the csv_path argument when running the script.
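A quick sanity check of the CSV before passing it to the script might look like the sketch below. It assumes only that the file has a header row; the column names are not documented here, so the helper (a hypothetical name) simply reports whatever it finds.

```python
import csv


def summarize_csv(csv_path):
    """Return (header_columns, number_of_non_empty_data_rows) for a CSV file."""
    with open(csv_path, newline="") as handle:
        reader = csv.reader(handle)
        header = next(reader, [])  # empty list if the file has no rows at all
        n_rows = sum(1 for row in reader if row)  # skip blank lines
    return header, n_rows
```

An empty header or a row count of zero usually means the path is wrong or the file was not exported correctly.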

Running the Project

You can run the project in different modes:

  1. Training:
python main.py train --agent_type q_learning --alpha 0.8 --algorithm q_learning
  2. Evaluation: This mode uses the community risk file aggregated_weekly_risk_levels.csv, located in the project's root directory.
python main.py eval --agent_type q_learning --alpha 0.8 --run_name your_run_name --csv_path aggregated_weekly_risk_levels.csv --algorithm q_learning
  3. Combined Training and Evaluation:
python main.py train_and_eval --agent_type q_learning --alpha 0.8 --csv_path aggregated_weekly_risk_levels.csv --algorithm q_learning
  4. Hyperparameter Sweep:
python main.py sweep --agent_type q_learning
  5. Multiple Runs:
python main.py multi --agent_type q_learning --alpha_t 0.05 --beta_t 0.9 --num_runs 10
  6. Optuna Optimization:
python main.py optuna --agent_type q_learning

Replace q_learning with dqn if you want to use the DQN algorithm instead.
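When launching several modes from a script, the commands above can be assembled programmatically. The sketch below uses only the modes and flags shown in this README; build_command is a hypothetical helper, and the subprocess call is left commented out so nothing runs on import.

```python
import sys


def build_command(mode, agent_type="q_learning", **flags):
    """Build an argument list such as: python main.py <mode> --agent_type <agent_type> ..."""
    cmd = [sys.executable, "main.py", mode, "--agent_type", agent_type]
    for name, value in flags.items():
        # Keyword arguments become --name value pairs, in the order given.
        cmd += [f"--{name}", str(value)]
    return cmd


train_cmd = build_command("train", alpha=0.8, algorithm="q_learning")
# To actually launch the run from the project root:
# import subprocess
# subprocess.run(train_cmd, check=True)
```

Passing agent_type="dqn" and algorithm="dqn" produces the DQN variant of the same command.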

Troubleshooting

If you encounter any issues during installation or running the project, please check the following:

  1. Ensure all prerequisites are correctly installed.
  2. Verify that you're in the correct conda environment (conda activate campus_gym).
  3. Check that all configuration files are properly set up.
  4. Make sure you have the necessary permissions to read/write in the project directory.
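Some of the checks above can be automated. The sketch below is a hypothetical helper, not part of the project; it assumes the environment name campus_gym from the installation steps and relies on the CONDA_DEFAULT_ENV variable that conda sets when an environment is activated.

```python
import os
import sys


def environment_report(project_dir=".", expected_env="campus_gym"):
    """Return a dict of boolean checks mirroring the troubleshooting list."""
    return {
        # Prerequisite: Python 3.8 or higher.
        "python_ok": sys.version_info >= (3, 8),
        # conda sets CONDA_DEFAULT_ENV to the name of the active environment.
        "conda_env_ok": os.environ.get("CONDA_DEFAULT_ENV") == expected_env,
        # Read and write permission on the project directory.
        "can_read_write": os.access(project_dir, os.R_OK | os.W_OK),
    }
```

Any False entry points to the corresponding item in the list above.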

If problems persist, please open an issue on the GitHub repository with details about the error and your environment.

Results

This section provides detailed information on how to view, analyze, and interpret the results generated by the project. The results include training metrics, evaluation metrics, visualizations, and safety analysis.

1. Viewing Results

1.1. Weights & Biases Integration

The project uses Weights & Biases (W&B) to log and visualize training and evaluation metrics. You can view the results by logging into your W&B account and navigating to the project dashboard.

1.1.1. Metrics Logged

  • Cumulative Reward: The total reward accumulated over the course of an episode.

1.2. Local CSV Files

For each training and evaluation run, metrics are also saved locally as CSV files in the results directory. The following CSV files are generated:

  • training_metrics_<run_name>.csv: Contains training metrics for each episode.
  • evaluation_metrics_<run_name>.csv: Contains evaluation metrics for each step within an episode.
  • mean_allowed_infected.csv: Summarizes the mean allowed and infected values across all episodes.

These files can be found in the results subdirectory specific to each run, typically located at <results_directory>/<agent_type>/<run_name>/<timestamp>/.
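To find the most recent output for a run, the directory layout above can be walked with pathlib. This is a sketch under one assumption not confirmed by this README: that timestamp directory names sort chronologically when sorted as strings. Both helper names are hypothetical.

```python
from pathlib import Path


def latest_run_dir(results_dir, agent_type, run_name):
    """Return the last timestamp directory for a run, or None if none exist."""
    run_root = Path(results_dir) / agent_type / run_name
    if not run_root.is_dir():
        return None
    # Assumption: timestamp names sort chronologically as plain strings.
    stamped = sorted(path for path in run_root.iterdir() if path.is_dir())
    return stamped[-1] if stamped else None


def list_metric_files(run_dir):
    """Return the sorted names of the CSV files inside a run directory."""
    return sorted(path.name for path in Path(run_dir).glob("*.csv"))
```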

2. Visualizations

All visualizations are saved in the results subdirectory for each run. The specific path is <results_directory>/<agent_type>/<run_name>/<timestamp>/.

2.1. Tolerance Interval Curve

The Tolerance Interval Curve shows the range expected to contain a given proportion (beta) of returns, at a confidence level determined by alpha. This curve helps visualize how consistent the model's performance is across runs.

  • tolerance_interval_mean.png: The mean performance with the tolerance interval.
  • tolerance_interval_median.png: The median performance with the tolerance interval.
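To build intuition for how alpha, beta, and the number of runs interact, the sketch below uses one standard distribution-free construction (Wilks' result for the interval between the smallest and largest observed return). This is not necessarily the method the project uses, and it assumes alpha denotes one minus the confidence level; the function names are hypothetical.

```python
def minmax_coverage_confidence(n_runs, beta):
    """Confidence that [min, max] of n_runs i.i.d. returns covers at least a
    proportion beta of the return distribution (Wilks, two-sided)."""
    if n_runs < 2:
        return 0.0
    return 1.0 - n_runs * beta ** (n_runs - 1) + (n_runs - 1) * beta ** n_runs


def min_runs_for_tolerance(alpha, beta, max_runs=10_000):
    """Smallest number of runs for which [min, max] is a (1 - alpha, beta)
    tolerance interval, or None if max_runs is not enough."""
    for n_runs in range(2, max_runs + 1):
        if minmax_coverage_confidence(n_runs, beta) >= 1.0 - alpha:
            return n_runs
    return None
```

Under this construction, 10 runs with beta = 0.9 give a confidence of only about 0.26, and min_runs_for_tolerance(0.05, 0.9) returns 46, so the min-to-max range of a small number of runs should be read as a loose indication rather than a tight guarantee.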

2.2. Confidence Interval Curve

The Confidence Interval Curve shows the mean performance of the model with a confidence interval, typically at 95%. This visualization helps assess the reliability of the model's performance.

  • confidence_interval.png: The confidence interval for mean performance.
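As an illustration of what such an interval represents, the sketch below computes a confidence interval for the mean return using a normal approximation. The project may compute its interval differently (for example with a t-distribution or bootstrapping); the function name is hypothetical.

```python
from math import sqrt
from statistics import NormalDist, mean, stdev


def mean_confidence_interval(values, confidence=0.95):
    """Return (lower, upper) bounds for the mean under a normal approximation."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two values to estimate a standard error")
    center = mean(values)
    # Two-sided critical value, roughly 1.96 for 95% confidence.
    z = NormalDist().inv_cdf(0.5 + confidence / 2.0)
    half_width = z * stdev(values) / sqrt(n)
    return center - half_width, center + half_width
```

With few runs the normal approximation understates the uncertainty, so the resulting interval is best treated as a lower bound on the true width.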

2.3. Safety Set Identification

The Safety Set Identification plot shows the states where the model's policy maintains safety constraints, such as keeping the number of infections below a threshold.

  • safety_set_plot_episode_<run_name>.png: A plot showing the safety set for a specific episode.

2.4. Evaluation Results Plot

For Q-learning, an evaluation plot is generated showing the allowed students, infected individuals, and community risk over time.

  • evaluation_plot_<run_name>.png: Plot of evaluation results.

3. Safety Analysis

3.1. Safety Set Conditions

The safety set conditions are logged in the safety_conditions_<run_name>.csv file. This file contains the following columns:

  • Episode: The episode number.
  • Infections > Threshold (%): The percentage of time the number of infections exceeded the threshold.
  • Safety Condition Met (Infection): Indicates whether the infection safety condition was met.
  • Allowed Students ≥ Threshold (%): The percentage of time the allowed students met or exceeded the threshold.
  • Safety Condition Met (Attendance): Indicates whether the attendance safety condition was met.

This file provides a quick overview of the model's performance and safety compliance across the entire evaluation period.
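The two percentage columns can be reproduced from per-step evaluation data as sketched below. The thresholds and input lists are hypothetical, and the rule the project uses to decide whether a condition counts as met is not documented here, so the sketch stops at the percentages.

```python
def percent_of_steps(values, predicate):
    """Return the percentage of entries in values for which predicate is true."""
    if not values:
        return 0.0
    return 100.0 * sum(1 for value in values if predicate(value)) / len(values)


def safety_summary(infected, allowed, infection_threshold, attendance_threshold):
    """Compute the two percentage columns for one episode of per-step data."""
    return {
        "Infections > Threshold (%)": percent_of_steps(
            infected, lambda value: value > infection_threshold
        ),
        "Allowed Students ≥ Threshold (%)": percent_of_steps(
            allowed, lambda value: value >= attendance_threshold
        ),
    }
```

Note the asymmetry in the comparisons: infections count only when strictly above the threshold, while attendance counts when it meets or exceeds it, matching the column names.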


Use the above results and visualizations to assess and refine your model. The combination of metrics, visualizations, and safety analysis will help you understand the model's strengths and weaknesses and guide further development. Remember to check the specific results subdirectory for each run to find all the generated files and visualizations.
