
Add normalized residual score in regression #68


Merged: 7 commits, May 19, 2025. Changes shown from all commits.
1 change: 1 addition & 0 deletions README.md
@@ -104,6 +104,7 @@ TorchCP has implemented the following methods:
| 2020 | [**Adaptive, Distribution-Free Prediction Intervals for Deep Networks**](https://proceedings.mlr.press/v108/kivaranovic20a.html) | AISTATS'20 | [Link](https://github.com/yromano/cqr) | regression.score.cqrfm | |
| 2019 | [**Conformalized Quantile Regression**](https://proceedings.neurips.cc/paper_files/paper/2019/file/5103c3584b063c431bd1268e9b5e76fb-Paper.pdf) | NeurIPS'19 | [Link](https://github.com/yromano/cqr) | regression.score.cqr | |
| 2017 | [**Distribution-Free Predictive Inference For Regression**](https://arxiv.org/abs/1604.04173) | JASA | [Link](https://github.com/ryantibs/conformal) | regression.predictor.split | |
| 2005 | [**Algorithmic Learning in a Random World**](https://link.springer.com/book/10.1007/b106715) | Springer | | regression.score.abs, regression.score.norabs | |

## Graph

57 changes: 57 additions & 0 deletions tests/regression/score/test_norabs.py
@@ -0,0 +1,57 @@
import pytest
import torch

from torchcp.regression.score.norabs import NorABS


@pytest.fixture
def norabs_instance():
    return NorABS()


def test_call_with_2d_output(norabs_instance):
    predicts = torch.tensor([[0.5, 0.1], [0.4, 0.2]])
    y_truth = torch.tensor([0.6, 0.1])

    scores = norabs_instance(predicts, y_truth)
    expected_scores = torch.tensor([[1.0], [1.5]])
    assert torch.allclose(scores, expected_scores), "The __call__ output is incorrect."


def test_call_with_1d_scores(norabs_instance):
    # Ensure the unsqueeze works if the score is 1D
    predicts = torch.tensor([[0.5, 1.0]])
    y_truth = torch.tensor([0.0])
    scores = norabs_instance(predicts, y_truth)
    assert scores.shape == (1, 1), "The score should be 2D with shape (1, 1)."


def test_generate_intervals_single_threshold(norabs_instance):
    predicts = torch.tensor([[0.5, 0.1], [0.4, 0.2]])
    q_hat = torch.tensor([2.0])

    intervals = norabs_instance.generate_intervals(predicts, q_hat)
    expected = torch.tensor([[[0.3, 0.7]], [[0.0, 0.8]]])
    assert torch.allclose(intervals, expected), "Interval calculation failed for single threshold."


def test_generate_intervals_multi_threshold(norabs_instance):
    predicts = torch.tensor([[0.5, 0.1]])
    q_hat = torch.tensor([1.0, 2.0])

    intervals = norabs_instance.generate_intervals(predicts, q_hat)
    expected = torch.tensor([[[0.4, 0.6], [0.3, 0.7]]])
    assert torch.allclose(intervals, expected), "Interval calculation failed for multiple thresholds."


def test_train_returns_model(norabs_instance, dummy_data):
    train_dataloader, _ = dummy_data
    # Run training twice to exercise both the verbose and silent code paths.
    model = norabs_instance.train(train_dataloader, epochs=2, verbose=True)
    model = norabs_instance.train(train_dataloader, epochs=2, verbose=False)
    sample_input = next(iter(train_dataloader))[0]

    with torch.no_grad():
        output = model(sample_input)
    assert output.shape[1] == 2, "The trained model should output both mean and variance."
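
The `dummy_data` fixture used above is not defined in this file; it presumably lives in the suite's shared `conftest.py`. A minimal sketch of a compatible fixture (shapes, sizes, and the file location are illustrative assumptions, not the repository's actual fixture):

```python
# tests/regression/conftest.py -- hypothetical sketch of the shared fixture
import pytest
import torch
from torch.utils.data import TensorDataset, DataLoader


@pytest.fixture
def dummy_data():
    # Small synthetic regression problem: 100 samples, 5 features.
    x = torch.randn(100, 5)
    y = x.sum(dim=1) + 0.1 * torch.randn(100)
    train_dataloader = DataLoader(TensorDataset(x, y), batch_size=16)
    test_dataloader = DataLoader(TensorDataset(x, y), batch_size=16)
    return train_dataloader, test_dataloader
```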
8 changes: 5 additions & 3 deletions torchcp/regression/predictor/split.py
@@ -50,18 +50,20 @@ def train(self, train_dataloader, **kwargs):
         If the train function is not used, users should pass the trained model to the predictor at the beginning.
         """
         model = kwargs.pop('model', None)
+        device = kwargs.pop('device', self._device)
+        self._device = device
 
         if model is not None:
             self._model = self.score_function.train(
-                train_dataloader, model=model, device=self._device, **kwargs
+                train_dataloader, model=model, device=device, **kwargs
             )
         elif self._model is not None:
             self._model = self.score_function.train(
-                train_dataloader, model=self._model, device=self._device, **kwargs
+                train_dataloader, model=self._model, device=device, **kwargs
             )
         else:
             self._model = self.score_function.train(
-                train_dataloader, device=self._device, **kwargs
+                train_dataloader, device=device, **kwargs
             )
 
     def calculate_score(self, predicts, y_truth):
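
This change lets callers override the device at train time. A hedged usage sketch (the `SplitPredictor` constructor signature is an assumption here; the synthetic loader mirrors the conftest sketch above):

```python
# Hypothetical usage sketch, not part of this PR.
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchcp.regression.predictor import SplitPredictor
from torchcp.regression.score import NorABS

x = torch.randn(100, 5)
y = x.sum(dim=1) + 0.1 * torch.randn(100)
train_dataloader = DataLoader(TensorDataset(x, y), batch_size=16)

predictor = SplitPredictor(NorABS())
# 'device' is popped before the remaining kwargs are forwarded, so the score
# function's train() sees it exactly once; it also updates the predictor's device.
predictor.train(train_dataloader, device=torch.device("cpu"), epochs=5)
```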
1 change: 1 addition & 0 deletions torchcp/regression/score/__init__.py
@@ -6,6 +6,7 @@
 #
 
 from .abs import ABS
+from .norabs import NorABS
 from .cqr import CQR
 from .cqrfm import CQRFM
 from .cqrm import CQRM
7 changes: 3 additions & 4 deletions torchcp/regression/score/abs.py
@@ -15,15 +15,14 @@
 
 class ABS(BaseScore):
     """
-    Split Conformal Prediction for Regression.
+    Absolute value of the difference between prediction and true value.
 
     This score function allows for calculating scores and generating prediction intervals
     using a single-point regression model.
 
     Reference:
-        Paper: Distribution-Free Predictive Inference For Regression (Lei et al., 2017)
-        Link: https://arxiv.org/abs/1604.04173
-        Github: https://github.com/ryantibs/conformal
+        Book: Algorithmic Learning in a Random World (Vovk et al., 2005)
+        Link: https://link.springer.com/book/10.1007/b106715
     """
 
     def __init__(self):
12 changes: 10 additions & 2 deletions torchcp/regression/score/base.py
@@ -76,7 +76,11 @@ def _basetrain(self, model, epochs, train_dataloader, criterion, optimizer, verbose=True):
                 running_loss = 0.0
                 for index, (tmp_x, tmp_y) in enumerate(train_dataloader):
                     outputs = model(tmp_x.to(device))
-                    loss = criterion(outputs, tmp_y.reshape(-1, 1).to(device))
+                    if criterion.__class__.__name__ == 'GaussianNLLLoss':
+                        mu, var = outputs[..., 0], outputs[..., 1]
+                        loss = criterion(mu, tmp_y.reshape(-1).to(device), var)
+                    else:
+                        loss = criterion(outputs, tmp_y.reshape(-1, 1).to(device))
                     optimizer.zero_grad()
                     loss.backward()
                     optimizer.step()
@@ -87,7 +91,11 @@ def _basetrain(self, model, epochs, train_dataloader, criterion, optimizer, verbose=True):
         else:
             for tmp_x, tmp_y in train_dataloader:
                 outputs = model(tmp_x.to(device))
-                loss = criterion(outputs, tmp_y.reshape(-1, 1).to(device))
+                if criterion.__class__.__name__ == 'GaussianNLLLoss':
+                    mu, var = outputs[..., 0], outputs[..., 1]
+                    loss = criterion(mu, tmp_y.reshape(-1).to(device), var)
+                else:
+                    loss = criterion(outputs, tmp_y.reshape(-1, 1).to(device))
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
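
The dispatch on `criterion.__class__.__name__` is needed because `nn.GaussianNLLLoss` takes three arguments, `(input, target, var)`, rather than two. Note the target is flattened to match the 1-D mean: a `(batch, 1)` target against a `(batch,)` mean would broadcast to `(batch, batch)` inside the loss and silently skew it. A quick standalone shape check (illustrative only, not part of the PR):

```python
# Shape check for nn.GaussianNLLLoss: input, target, and var should align.
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.GaussianNLLLoss()
outputs = torch.randn(8, 2)                     # two-headed output: (mean, raw variance)
mu = outputs[..., 0]                            # (8,)
var = F.softplus(outputs[..., 1]) + 1e-6        # positive variance, (8,)
target = torch.randn(8)                         # keep 1-D so mu - target stays (8,)

loss = criterion(mu, target, var)               # scalar with the default 'mean' reduction
print(loss.item())
```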
100 changes: 100 additions & 0 deletions torchcp/regression/score/norabs.py
@@ -0,0 +1,100 @@
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn
import torch.optim as optim

from torchcp.regression.score.abs import ABS
from torchcp.regression.utils import build_regression_model


class NorABS(ABS):
    """
    Normalized Absolute Score (NorABS) used for conformal regression.

    This score function computes the absolute difference between the predicted mean and
    the ground truth value, normalized by the model's predicted uncertainty. It is
    designed for probabilistic regression models that predict both a mean and an
    uncertainty estimate (the variance head of the default GaussianRegressionModel).

    Reference:
        Book: Algorithmic Learning in a Random World (Vovk et al., 2005)
        Link: https://link.springer.com/book/10.1007/b106715
    """

    def __init__(self):
        super().__init__()

    def __call__(self, predicts, y_truth):
        """
        Computes the normalized non-conformity score for conformal prediction.

        Args:
            predicts (torch.Tensor): Tensor containing the predicted mean and uncertainty,
                shape (batch_size, 2), where the first column is the mean (mu) and the
                second column is the predicted uncertainty used for normalization.
            y_truth (torch.Tensor): Tensor of true target values, shape (batch_size,).

        Returns:
            torch.Tensor: Tensor of normalized absolute deviations, shape (batch_size, 1).
        """
        mu, var = predicts[..., 0], predicts[..., 1]
        scores = torch.abs(mu - y_truth) / var
        if len(scores.shape) == 1:
            scores = scores.unsqueeze(1)
        return scores

    def generate_intervals(self, predicts_batch, q_hat):
        """
        Generates prediction intervals using the predicted means and uncertainties,
        scaled by the calibrated threshold :attr:`q_hat`.

        Args:
            predicts_batch (torch.Tensor): Tensor of predicted (mean, uncertainty), shape (batch_size, 2).
            q_hat (torch.Tensor): Calibrated threshold values, shape (num_thresholds,).

        Returns:
            torch.Tensor: Prediction intervals, shape (batch_size, num_thresholds, 2),
                where the last dimension contains lower and upper bounds.
        """
        if len(predicts_batch.shape) == 2:
            predicts_batch = predicts_batch.unsqueeze(1)
        prediction_intervals = predicts_batch.new_zeros((predicts_batch.shape[0], q_hat.shape[0], 2))
        prediction_intervals[..., 0] = predicts_batch[..., 0] - q_hat.view(1, q_hat.shape[0]) * predicts_batch[..., 1]
        prediction_intervals[..., 1] = predicts_batch[..., 0] + q_hat.view(1, q_hat.shape[0]) * predicts_batch[..., 1]
        return prediction_intervals

    def train(self, train_dataloader, **kwargs):
        """
        Trains the probabilistic regression model to predict both mean and variance.

        Args:
            train_dataloader (DataLoader): DataLoader for the training data.
            **kwargs: Additional keyword arguments for training configuration.

                - model (nn.Module, optional): Custom regression model. If None, defaults to
                  GaussianRegressionModel.
                - epochs (int, optional): Number of training epochs. Defaults to 100.
                - criterion (nn.Module, optional): Loss function. Defaults to GaussianNLLLoss.
                - lr (float, optional): Learning rate. Defaults to 0.01.
                - optimizer (torch.optim.Optimizer, optional): Optimizer. Defaults to Adam.
                - verbose (bool, optional): Whether to print training progress. Defaults to True.

        Returns:
            nn.Module: The trained regression model.
        """
        device = kwargs.get('device', None)
        model = kwargs.get('model', None)
        if model is None:
            # Build the default two-headed (mean, variance) model lazily, so a
            # user-supplied model does not trigger a wasted construction.
            input_dim = next(iter(train_dataloader))[0].shape[1]
            model = build_regression_model("GaussianRegressionModel")(input_dim, 64, 0.5).to(device)
        epochs = kwargs.get('epochs', 100)
        criterion = kwargs.get('criterion', nn.GaussianNLLLoss())
        lr = kwargs.get('lr', 0.01)
        optimizer = kwargs.get('optimizer', optim.Adam(model.parameters(), lr=lr))
        verbose = kwargs.get('verbose', True)

        self._basetrain(model, epochs, train_dataloader, criterion, optimizer, verbose)
        return model
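
For orientation, a minimal end-to-end sketch of the new score on synthetic data; the calibration step is written out by hand here rather than going through a predictor class, and all sizes are illustrative:

```python
# Minimal NorABS sketch: train, calibrate, and build intervals by hand.
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchcp.regression.score import NorABS

x = torch.randn(200, 5)
y = x.sum(dim=1) + 0.1 * torch.randn(200)
train_loader = DataLoader(TensorDataset(x[:100], y[:100]), batch_size=16)

score = NorABS()
model = score.train(train_loader, epochs=5, verbose=False)

with torch.no_grad():
    cal_pred = model(x[100:150])                 # (50, 2): mean and uncertainty
    test_pred = model(x[150:])                   # (50, 2)

scores = score(cal_pred, y[100:150]).squeeze(1)  # (50,) calibration scores
alpha = 0.1
n = scores.shape[0]
q_level = min(1.0, (n + 1) * (1 - alpha) / n)    # conformal quantile level
q_hat = torch.quantile(scores, q_level).unsqueeze(0)  # shape (1,)

intervals = score.generate_intervals(test_pred, q_hat)  # (50, 1, 2)
```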
26 changes: 26 additions & 0 deletions torchcp/regression/utils/model.py
@@ -5,7 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 #
 
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class NonLinearNet(nn.Module):
@@ -39,12 +41,36 @@ def __init__(self, input_dim, output_dim, hidden_size, dropout):
 
     def forward(self, x):
         return self.base_model(x)
 
 
+class GaussianRegressionModel(nn.Module):
+    def __init__(self, input_dim, hidden_dim=64, dropout=0.5):
+        super(GaussianRegressionModel, self).__init__()
+        self.shared = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )
+        self.output_layer = nn.Linear(hidden_dim, 2)
+
+    def forward(self, x):
+        x = self.shared(x)
+        out = self.output_layer(x)
+        mu = out[..., 0]
+        # Softplus keeps the variance head strictly positive; the epsilon guards
+        # against numerical underflow in GaussianNLLLoss.
+        var = F.softplus(out[..., 1]) + 1e-6
+        return torch.stack([mu, var], dim=-1)
+
+
 def build_regression_model(model_name="NonLinearNet"):
     if model_name == "NonLinearNet":
         return NonLinearNet
     elif model_name == "NonLinearNet_with_Softmax":
         return Softmax
+    elif model_name == 'GaussianRegressionModel':
+        return GaussianRegressionModel
     else:
         raise NotImplementedError
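
A quick smoke test of the new model's output contract (illustrative, not part of the PR):

```python
# Smoke test: GaussianRegressionModel emits (mean, variance) pairs, variance positive.
import torch
from torchcp.regression.utils import build_regression_model

model = build_regression_model("GaussianRegressionModel")(input_dim=5)
out = model(torch.randn(8, 5))
assert out.shape == (8, 2)
assert (out[..., 1] > 0).all()  # softplus + 1e-6 keeps the variance head positive
```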