This repository contains the scripts and configurations needed to train machine learning models on particle physics datasets using PyTorch Lightning and various custom utilities. The main script, `main.py`, handles model setup, data loading, and the training loop, with support for logging and checkpointing.
To run the training script, you need the following:
- Python 3.8 or higher
- PyTorch 1.8 or higher
- PyTorch Lightning
- Wandb for logging and monitoring
- PyYAML for configuration management
- NumPy, pandas, and SciPy for data manipulation and computation
- tqdm for progress bars
- Clone the repository:

```bash
git clone git@github.com:kaechb/MDMA.git
cd MDMA
```

- Install the required packages:

```bash
pip install -r requirements.txt
```
The script uses YAML files for configuration (`hparams/default_[model_name].yaml`). Update these configuration files according to your model and dataset specifics. Parameters include the model type, dataset, optimizer settings, and more.
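As a rough illustration, the configuration can be inspected with PyYAML; the file name and keys below are placeholders rather than the repository's actual schema:

```python
# Minimal sketch of reading a training config; "mymodel" and the key
# names (model_name, lr) are illustrative assumptions, not the
# repository's actual schema.
import yaml

with open("hparams/default_mymodel.yaml") as f:  # hypothetical model name
    config = yaml.safe_load(f)

print(config.get("model_name"), config.get("lr"))
```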
To train a model, navigate to the directory containing the training script and execute:

```bash
python main.py [model_name]
```

Replace `[model_name]` with the specific model configuration name you want to use, which corresponds to a YAML file in `hparams/`.
The training script is organized around a few key functions:

- `setup_model`: Configures the model based on the YAML settings.
- `train`: Manages the training process, including data loading, model initialization, and training loop execution.
- `setup_scaler_calo`: Configures data scaling based on dataset characteristics.
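For orientation, here is a minimal sketch of the dispatch pattern a function like `setup_model` typically follows; the model classes and config keys are placeholders, not the repository's actual implementation:

```python
# Hypothetical sketch of config-driven model dispatch; the real setup_model
# lives in main.py, and these classes are stand-ins for LightningModules.
class FlowModel:  # placeholder model class
    def __init__(self, **hparams):
        self.hparams = hparams

class TransformerModel:  # placeholder model class
    def __init__(self, **hparams):
        self.hparams = hparams

def setup_model(config):
    """Instantiate a model from the parsed YAML config (assumed schema)."""
    registry = {"flow": FlowModel, "transformer": TransformerModel}
    name = config["model_name"]  # assumed config key
    if name not in registry:
        raise ValueError(f"Unknown model: {name}")
    return registry[name](**config)

model = setup_model({"model_name": "flow", "lr": 1e-3})
```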
- Wandb: Ensure you have an active Wandb account and your environment is set up to log metrics and outputs. The script initializes Wandb logging and will sync results to your dashboard.
- TensorBoard: Local logging with TensorBoard is supported if configured in the YAML file.
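A minimal sketch of wiring either logger into a PyTorch Lightning `Trainer`; the project and run names are placeholders, not values taken from this repository:

```python
# Standard PyTorch Lightning logger setup; swap WandbLogger for
# TensorBoardLogger to log locally instead of to Wandb.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger

logger = WandbLogger(project="mdma", name="example-run")  # hypothetical names
# logger = TensorBoardLogger(save_dir="logs/")            # local alternative

trainer = Trainer(logger=logger)
```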
The script supports saving and loading from checkpoints, allowing training to be paused and resumed. Checkpoints are saved based on validation metrics, ensuring the best models are preserved.
Several PyTorch Lightning callbacks are utilized to enhance training:

- `LearningRateMonitor`: Tracks and logs the learning rate.
- `ModelCheckpoint`: Manages the saving of model states.
- `EMA`: Maintains an exponential moving average of model parameters for more stable training.
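A minimal sketch of assembling these callbacks and resuming from a saved checkpoint; the monitored metric name and checkpoint path are assumptions, and the project's custom `EMA` callback is only referenced, not reimplemented here:

```python
# LearningRateMonitor and ModelCheckpoint are standard PyTorch Lightning
# callbacks; the EMA callback is assumed to be a custom utility shipped
# with this repository and is omitted from this sketch.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

callbacks = [
    LearningRateMonitor(logging_interval="step"),
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),  # assumed metric
]
trainer = Trainer(callbacks=callbacks, max_epochs=10)

# Resuming a paused run (Lightning 1.5+); the path is a placeholder:
# trainer.fit(model, datamodule=dm, ckpt_path="checkpoints/best.ckpt")
```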
- The training script is designed to be modular. Additional models and configurations can be added by updating the `setup_model` function and the corresponding YAML files.
- For different datasets, ensure appropriate data modules are available and configured.
- Ensure CUDA devices are properly configured if using GPU acceleration.
- Check Wandb configuration and internet connectivity for remote logging issues.
- Validate that data paths and formats match the inputs expected by the models and data loaders; a quick sanity-check sketch follows this list.
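A quick sanity check covering the points above; the data directory is a hypothetical path, not the repository's actual layout:

```python
# Environment sanity checks before launching training.
import os
import torch

print("CUDA available:", torch.cuda.is_available())
print("GPU count:     ", torch.cuda.device_count())
print("Data dir found:", os.path.isdir("data/"))  # placeholder path
```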