Lightweight machine learning framework for plug-and-play training of PyTorch models
- ⚡ Lightning-inspired
- 💾 Support for wandb and checkpoints out of the box
- 📊 Pretty logs, plots and support for metrics
- ✨ Fully type-safe
- 🪶 Lightweight and easy to use
See the 📓 notebooks for examples using pyro. In particular, you can find:
- Iris: the simplest example, training a small MLP on the Iris dataset.
- SmolVLM on Flowers102: features extracted from the SmolVLM vision model are used to train a linear classifier on the Flowers102 dataset, reaching 98.6% test accuracy (a minimal sketch of this recipe follows the list).
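The general recipe behind that notebook is a frozen backbone producing features, with a small linear head trained on top of them. The snippet below is only an illustrative sketch of that idea; `vision_backbone`, `extract_features` and `feature_dim` are hypothetical placeholders, not SmolVLM or pyro APIs, so refer to the notebook for the actual code.

```python
import torch

# Illustrative sketch of the frozen-backbone + linear-probe recipe.
# `vision_backbone` and `feature_dim` are hypothetical placeholders.

@torch.no_grad()
def extract_features(vision_backbone: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
    vision_backbone.eval()
    return vision_backbone(images)  # expected shape: (batch, feature_dim)

feature_dim, num_classes = 768, 102  # Flowers102 has 102 classes
linear_head = torch.nn.Linear(feature_dim, num_classes)
# Wrap the precomputed features and labels in a dataset, then train `linear_head`
# with a pyro Trainer exactly like the getting-started example below.
```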
You can use 🔥 pyro with minimal code changes and forever forget about writing training loops. Here is an example of a pyro model and training script to get you started.
```python
import torch

import pyroml as p


class MySOTAModel(p.PyroModule):
    def __init__(self):
        super().__init__()
        # Define your layers here, and pick any loss function you like
        self.loss_fn = torch.nn.MyLossFunction()

    # Optionally, configure your own optimizer and scheduler, see more in the docs
    def configure_optimizers(self, tr):
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=tr.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.99
        )

    def step(self, batch, stage: p.Stage):
        # Extract data from your dataset batch
        # Batches and model are moved to the appropriate device automatically
        x, y = batch
        # Forward the model
        preds = self(x)
        # Compute the loss
        loss = self.loss_fn(preds, y)
        # Optionally, register some metrics
        self.log(loss=loss.item(), accuracy=compute_accuracy(preds, y))
        # Return the loss when training, otherwise return predictions
        if stage == p.Stage.TRAIN:
            return loss
        return preds


model = MySOTAModel()

trainer = p.Trainer(
    lr=0.01,
    max_epochs=32,
    batch_size=16,
    # And many other options such as device, precision, callbacks, ...
)

# Fit the model on the given training set and evaluate it during training
train_tracker = trainer.fit(model, training_dataset, validation_dataset)
print(train_tracker.records)

# Plot metric curves registered during training
train_tracker.plot(epoch=True)

# Evaluate your model after training
validation_tracker = trainer.evaluate(model, validation_dataset)
print(validation_tracker.records)

# Test your model on some testing set
_, test_preds = trainer.predict(model, test_dataset)
print("Test Predictions", test_preds)
```
Before proceeding, we suggest taking a quick look at the pyproject.toml for the extra groups you might want. Below are package-manager-specific commands to install pyro along with its dependencies.
```bash
# For default installation
uv add pyroml

# Add an extra group to target a specific PyTorch build, either cpu or cu124
uv add pyroml --extra cpu  # or: uv add pyroml --extra cu124

# To also install additional dependencies that you might require
uv add pyroml --extra extra
```
```bash
# CPU only version
poetry add pyroml

# OR with CUDA-enabled PyTorch and torchvision
poetry add pyroml[cuda] --source pytorch-cu124

# To also install additional dependencies that you might require
poetry add [...] --extras extra
```
```bash
# CPU only version
pip install pyroml

# OR with CUDA-enabled PyTorch and torchvision
pip install pyroml[cuda]

# To also install additional dependencies that you might require
pip install pyroml[extra]
```
```bash
# Clone the repo
git clone https://github.com/peacefulotter/pyroml.git
cd pyroml

# For basic installation, add an extra group to target a specific PyTorch build
uv sync --extra cpu  # or: uv sync --extra cu124

# To also install additional packages
uv sync --extra extra

# To install all possibly required packages
uv sync --extra full
```
```bash
# Install dependencies
poetry config virtualenvs.in-project true
poetry install --with cpu,dev # ,cuda
```
Running tests is straightforward thanks to pytest. First install the test dependencies, then run the script:

```bash
poetry install --with test
./run_tests.sh
```