diff --git a/CHANGELOG.md b/CHANGELOG.md
index 916da018..24cb8d30 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,27 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 
 *Note*: We try to adhere to these practices as of version [v0.2.1].
 
+## Version [0.3.1] - 2024-06-22
+
+### Changed
+
+- Changed `glm_predictive_distribution` so that it returns the tuple (Normal distribution, fμ, fvar) rather than the tuple (mean, variance). [#90]
+
+## Version [0.3.0] - 2024-06-21
+
+### Changed
+
+- Changed `glm_predictive_distribution` so that it returns a Normal distribution rather than the tuple (mean, variance). [#90]
+- Changed `predict` so that it directly returns a Normal distribution in the case of regression. [#90]
+
+### Added
+
+- Added functions to compute the average empirical frequency for both classification and regression problems in utils.jl. [#90]
+
+
+
+
+
 ## Version [0.2.1] - 2024-05-29
 
 ### Changed
diff --git a/Project.toml b/Project.toml
index 24355ed3..270fea5f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,6 +7,7 @@ version = "0.2.1"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
@@ -24,6 +25,7 @@ Aqua = "0.8"
 ChainRulesCore = "1.23.0"
 Compat = "4.7.0"
 ComputationalResources = "0.3.2"
+Distributions = "0.25.109"
 Flux = "0.12, 0.13, 0.14"
 LinearAlgebra = "1.6, 1.7, 1.8, 1.9, 1.10"
 MLJFlux = "0.2.10, 0.3, 0.4"
diff --git a/docs/Project.toml b/docs/Project.toml
index 49942c85..fedf9ac8 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
@@ -9,5 +10,6 @@ RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 TaijaPlotting = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
+Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
diff --git a/docs/src/tutorials/calibration.md b/docs/src/tutorials/calibration.md
new file mode 100644
index 00000000..19db8ba2
--- /dev/null
+++ b/docs/src/tutorials/calibration.md
@@ -0,0 +1,9 @@
+# Uncertainty Calibration
+## The issue of calibrated uncertainty distributions
+Bayesian methods offer a general framework for quantifying uncertainty. However, due to model misspecification and the use of approximate inference, Bayesian uncertainty estimates are often inaccurate: for example, a 90% credible interval may not contain the true outcome 90% of the time. A model is considered calibrated when its uncertainty estimates, such as Bayesian credible intervals, accurately reflect the true likelihood of outcomes: a 90% credible interval is calibrated if it contains the true outcome approximately 90% of the time, thereby indicating the reliability of the inference method. In other words, a good forecaster must be calibrated. Perfect calibration means that this agreement holds at every confidence level, not only at 90%.
+
+
+## Calibration Plots
+
+A calibration plot compares the empirical frequency of outcomes against the corresponding confidence levels: the closer the resulting curve is to the diagonal, the better calibrated the model.
+LaplaceRedux provides `empirical_frequency_regression` and `empirical_frequency_binary_classification` to compute these empirical frequencies, and `sharpness_regression` and `sharpness_classification` to measure how concentrated the predictive distributions are.
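+
+Below is a minimal sketch of how these functions can be used for a binary classifier. Here `la` is assumed to be an already fitted Laplace approximation and `X_cal`, `y_cal` a calibration set with labels coded as 0/1; the variable names and the number of bins are only illustrative.
+
+``` julia
+using LaplaceRedux
+
+# predicted probability of the target class for each calibration point (1×N array)
+probs = predict(la, X_cal)
+
+# stack target-class and null-class probabilities row-wise, as expected by the helper functions
+sampled_distributions = vcat(probs, 1 .- probs)
+
+# empirical frequency of the target class per probability bin
+num_p_per_interval, emp_avg, bin_centers =
+    empirical_frequency_binary_classification(y_cal, sampled_distributions; n_bins=20)
+
+# sharpness: average predicted probability for each class
+mean_class_one, mean_class_zero = sharpness_classification(y_cal, sampled_distributions)
+```
+
+Plotting `emp_avg` against `bin_centers` then gives the calibration curve for the classifier.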
diff --git a/docs/src/tutorials/logit.md b/docs/src/tutorials/logit.md
index e5185187..bee5f5d5 100644
--- a/docs/src/tutorials/logit.md
+++ b/docs/src/tutorials/logit.md
@@ -81,3 +81,4 @@ p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned
 p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
 plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))
 ```
+![](logit_files/figure-commonmark/cell-output-1.svg)
\ No newline at end of file
diff --git a/docs/src/tutorials/regression.md b/docs/src/tutorials/regression.md
index eb660591..d2cef2fa 100644
--- a/docs/src/tutorials/regression.md
+++ b/docs/src/tutorials/regression.md
@@ -132,3 +132,11 @@ plot(la, X, y; zoom=-5, size=(400,400))
     Scatter: 8.497215713339543
 
 ![](regression_files/figure-commonmark/cell-7-output-5.svg)
+
+
+## Calibration Plot
+Once the prior precision has been optimized, we can assess the quality of the predictive distribution
+through a calibration plot (see the [Uncertainty Calibration](calibration.md) tutorial).
+
+![](regression_files/figure-commonmark/miscalibration.svg)
+
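+The figure above can be reproduced with a sketch along the following lines. This is not the exact code used to generate it: `la`, `X` and `y` are the fitted Laplace approximation and the data from the previous snippets, and the number of samples and bins are arbitrary choices.
+
+``` julia
+using Plots
+
+# for regression, predict now returns a vector of Normal distributions
+predicted_distributions = predict(la, X)
+
+# draw samples from each predictive distribution
+sampled_distributions = [rand(d, 100) for d in predicted_distributions]
+
+# empirical frequency of the targets at each confidence level
+n_bins = 20
+levels = collect(range(0; stop=1, length=n_bins + 1))
+emp_freq = empirical_frequency_regression(vec(y), sampled_distributions; n_bins=n_bins)
+
+# a well-calibrated model stays close to the diagonal
+plot(levels, emp_freq; label="observed", xlabel="confidence level", ylabel="empirical frequency")
+plot!(levels, levels; linestyle=:dash, label="perfect calibration")
+```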
diff --git a/docs/src/tutorials/regression_files/figure-commonmark/miscalibration.svg b/docs/src/tutorials/regression_files/figure-commonmark/miscalibration.svg
new file mode 100644
index 00000000..9ca6b50d
--- /dev/null
+++ b/docs/src/tutorials/regression_files/figure-commonmark/miscalibration.svg
@@ -0,0 +1,56 @@
diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl
index 6ac5e95d..63f969f1 100644
--- a/src/LaplaceRedux.jl
+++ b/src/LaplaceRedux.jl
@@ -1,4 +1,7 @@
 module LaplaceRedux
+include("calibration_functions.jl")
+export empirical_frequency_binary_classification,
+    sharpness_classification, empirical_frequency_regression, sharpness_regression
 
 include("utils.jl")
 
diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl
index 49773a23..9891e0fb 100644
--- a/src/baselaplace/predicting.jl
+++ b/src/baselaplace/predicting.jl
@@ -1,3 +1,5 @@
+using Distributions: Distributions
+using Statistics: mean, var
 """
     functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)
 
@@ -22,6 +24,8 @@ Computes the linearized GLM predictive.
 - `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux.
 - `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux.
+- `normal_distr::Vector{Distributions.Normal}`: Vector of Normal distributions approximating the predictive distribution p(y|X) given the input data X; returned as the first element of the tuple.
+
 
 # Examples
 
 ```julia-repl
@@ -39,7 +43,9 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
     fμ = reshape(fμ, Flux.outputsize(la.model, size(X)))
     fvar = functional_variance(la, 𝐉)
     fvar = reshape(fvar, size(fμ)...)
-    return fμ, fvar
+    fstd = sqrt.(fvar)
+    normal_distr = [Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 2)]
+    return (normal_distr, fμ, fvar)
 end
 
 """
@@ -55,9 +61,11 @@ Computes predictions from Bayesian neural network.
 - `predict_proba::Bool=true`: If `true` (default), returns probabilities for classification tasks.
 
 # Returns
-
-- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
-- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux.
+For classification tasks, LaplaceRedux provides different options depending on `link_approx`:
+- `fμ::AbstractArray`: the mean of the predictive distribution if `link_approx` is set to `:plugin`.
+- `fμ::AbstractArray`: the probit approximation if `link_approx` is set to `:probit`.
+For regression tasks:
+- `normal_distr::Vector{Distributions.Normal}`: the vector of Normal distributions computed by `glm_predictive_distribution`.
 
 # Examples
 
@@ -75,11 +83,12 @@ predict(la, hcat(x...))
 ```
 """
 function predict(
     la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true
 )
-    fμ, fvar = glm_predictive_distribution(la, X)
+    normal_distr, fμ, fvar = glm_predictive_distribution(la, X)
+    #fμ, fvar = mean.(normal_distr), var.(normal_distr)
 
     # Regression:
     if la.likelihood == :regression
-        return fμ, fvar
+        return normal_distr
     end
 
     # Classification:
@@ -95,7 +104,7 @@ function predict(
     end
 
     # Sigmoid/Softmax
-    if predict_proba
+    if (predict_proba)
         if la.posterior.n_out == 1
             p = Flux.sigmoid(z)
         else
diff --git a/src/calibration_functions.jl b/src/calibration_functions.jl
new file mode 100644
index 00000000..fc7065da
--- /dev/null
+++ b/src/calibration_functions.jl
@@ -0,0 +1,141 @@
+using Statistics
+@doc raw"""
+    empirical_frequency_regression(Y_cal, sampled_distributions; n_bins=20)
+
+FOR REGRESSION MODELS. \
+Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ {1,...,T}`` and an array of predicted distributions, the function calculates the empirical frequency
+```math
+\hat{p}_j = \frac{|\{ y_t \mid F_t(y_t) \leq p_j,\; t = 1, \dots, T \}|}{T},
+```
+where ``T`` is the number of calibration points, ``p_j`` is the confidence level and ``F_t`` is the
+cumulative distribution function of the predicted distribution targeting ``y_t``. \
+Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)
+
+Inputs: \
+ - `Y_cal`: a vector of values ``y_t``. \
+ - `sampled_distributions`: a Vector{Vector{Float64}} of sampled distributions ``F(x_t)`` stacked row-wise. \
+   For example, `[rand(distr, 50) for distr in LaplaceRedux.predict(la, X)]`. \
+ - `n_bins`: number of equally spaced bins to use. \
+Outputs: \
+ - `counts`: an array containing the empirical frequencies for each quantile level.
+"""
+function empirical_frequency_regression(Y_cal, sampled_distributions; n_bins::Int=20)
+    if n_bins <= 0
+        throw(ArgumentError("n_bins must be a positive integer"))
+    end
+    n_edges = n_bins + 1
+    quantiles = collect(range(0; stop=1, length=n_edges))
+    quantiles_matrix = hcat(
+        [quantile(samples, quantiles) for samples in sampled_distributions]...
+    )
+    n_rows = size(quantiles_matrix, 1)
+    counts = Float64[]
+
+    for i in 1:n_rows
+        push!(counts, sum(Y_cal .<= quantiles_matrix[i, :]) / length(Y_cal))
+    end
+    return counts
+end
+
+@doc raw"""
+    sharpness_regression(sampled_distributions)
+
+FOR REGRESSION MODELS. \
+Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ {1,...,T}`` and an array of predicted distributions, the function calculates the
+sharpness of the predicted distributions, i.e., the average of the variances ``\sigma^2(F_t)`` predicted by the forecaster for each ``x_t``. \
+Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)
+
+Inputs: \
+ - `sampled_distributions`: a Vector{Vector{Float64}} of sampled distributions ``F(x_t)`` stacked row-wise. \
+Outputs: \
+ - `sharpness`: a scalar that measures the level of sharpness of the regressor.
+"""
+function sharpness_regression(sampled_distributions)
+    sharpness = mean(var.(sampled_distributions))
+    return sharpness
+end
+
+@doc raw"""
+    empirical_frequency_binary_classification(y_binary, sampled_distributions; n_bins=20)
+
+FOR BINARY CLASSIFICATION MODELS. \
+Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ {1,...,T}``, let ``p_t = H(x_t) ∈ [0,1]`` be the forecasted probability. \
+We group the ``p_t`` into intervals ``I_j`` for ``j = 1,2,...,m`` that form a partition of ``[0,1]``.
+The function computes the observed average ``p_j = T_j^{-1} \sum_{t : p_t ∈ I_j} y_t`` in each interval ``I_j``. \
+Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)
+
+Inputs: \
+ - `y_binary`: the array of outputs ``y_t`` numerically coded: 1 for the target class, 0 for the null class. \
+ - `sampled_distributions`: an array of sampled distributions stacked column-wise so that in the first row
+   there is the probability for the target class ``y_1`` and in the second row the probability for the null class ``y_0``. \
+ - `n_bins`: number of equally spaced bins to use.
+
+Outputs: \
+ - `num_p_per_interval`: array with the number of probabilities falling within each interval. \
+ - `emp_avg`: array with the observed empirical average per interval. \
+ - `bin_centers`: array with the centers of the bins.
+
+"""
+function empirical_frequency_binary_classification(
+    y_binary, sampled_distributions; n_bins::Int=20
+)
+    if n_bins <= 0
+        throw(ArgumentError("n_bins must be a positive integer"))
+    elseif !all(x -> x == 0 || x == 1, y_binary)
+        throw(ArgumentError("y_binary must be an array of 0 and 1"))
+    end
+    #intervals boundaries
+    n_edges = n_bins + 1
+    int_bds = collect(range(0; stop=1, length=n_edges))
+    #bin centers
+    bin_centers = [(int_bds[i] + int_bds[i + 1]) / 2 for i in 1:(length(int_bds) - 1)]
+    #initialize list for empirical averages per interval
+    emp_avg = []
+    #initialize list for predicted averages per interval
+    pred_avg = []
+    #initialize list of the number of probabilities falling within each interval
+    num_p_per_interval = []
+    #list of the predicted probabilities for the target class
+    class_probs = sampled_distributions[1, :]
+    # iterate over the bins
+    for j in 1:n_bins
+        push!(num_p_per_interval, sum(int_bds[j] .< class_probs .< int_bds[j + 1]))
+        if num_p_per_interval[j] == 0
+            push!(emp_avg, 0)
+            push!(pred_avg, bin_centers[j])
+
+        else
+            # find the indices of all instances for which class_probs falls within the j-th interval
+            indices = findall(x -> int_bds[j] < x < int_bds[j + 1], class_probs)
+            # compute the empirical average and save it in the j-th position of emp_avg
+            push!(emp_avg, 1 / num_p_per_interval[j] * sum(y_binary[indices]))
+            #TODO: maybe substitute with bin_centers?
+            push!(pred_avg, 1 / num_p_per_interval[j] * sum(class_probs[indices]))
+        end
+    end
+    #return the tuple
+    return (num_p_per_interval, emp_avg, bin_centers)
+end
+
+@doc raw"""
+    sharpness_classification(y_binary, sampled_distributions)
+
+FOR BINARY CLASSIFICATION MODELS. \
+Assess the sharpness of the model by looking at the distribution of model predictions.
+When forecasts are sharp, most predictions are close to either 0 or 1. \
+Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)
+
+Inputs: \
+ - `y_binary`: the array of outputs ``y_t`` numerically coded: 1 for the target class, 0 for the null class. \
+ - `sampled_distributions`: an array of sampled distributions stacked column-wise so that in the first row there is the probability for the target class ``y_1`` and in the second row the probability for the null class ``y_0``. \
+
+Outputs: \
+ - `mean_class_one`: a scalar that measures the average prediction for the target class. \
+ - `mean_class_zero`: a scalar that measures the average prediction for the null class.
+
+"""
+function sharpness_classification(y_binary, sampled_distributions)
+    mean_class_one = mean(sampled_distributions[1, findall(y_binary .== 1)])
+    mean_class_zero = mean(sampled_distributions[2, findall(y_binary .== 0)])
+    return mean_class_one, mean_class_zero
+end
diff --git a/src/utils.jl b/src/utils.jl
index 2bfb59ff..4d2a9194 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,4 +1,5 @@
 using Flux
+using Statistics
 
 """
     get_loss_fun(likelihood::Symbol)
diff --git a/test/Project.toml b/test/Project.toml
index 7af07ce3..e0aee2a2 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
diff --git a/test/calibration.jl b/test/calibration.jl
new file mode 100644
index 00000000..649f0b43
--- /dev/null
+++ b/test/calibration.jl
@@ -0,0 +1,120 @@
+using Statistics
+using LaplaceRedux
+using Distributions
+
+@testset "sharpness_classification tests" begin
+
+    # Test 1: Check that the function runs without errors and returns two scalars for a simple case
+    y_binary = [1, 0, 1, 0, 1]
+    sampled_distributions = [0.9 0.1 0.8 0.2 0.7; 0.1 0.9 0.2 0.8 0.3] # Sampled probabilities
+    mean_class_one, mean_class_zero = sharpness_classification(
+        y_binary, sampled_distributions
+    )
+    @test typeof(mean_class_one) <: Real # Check if mean_class_one is a scalar
+    @test typeof(mean_class_zero) <: Real # Check if mean_class_zero is a scalar
+
+    # Test 2: Check the function with a known input
+    y_binary = [0, 1, 0, 1, 1, 0, 1, 0]
+    sampled_distributions = [
+        0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8
+        0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2
+    ]
+    mean_class_one, mean_class_zero = sharpness_classification(
+        y_binary, sampled_distributions
+    )
+    @test mean_class_one ≈ mean(sampled_distributions[1, [2, 4, 5, 7]])
+    @test mean_class_zero ≈ mean(sampled_distributions[2, [1, 3, 6, 8]])
+
+    # Test 3: Edge case with all ones in y_binary
+    y_binary_all_ones = [1, 1, 1]
+    sampled_distributions_all_ones = [0.8 0.9 0.7; 0.2 0.1 0.3]
+    mean_class_one_all_ones, mean_class_zero_all_ones = sharpness_classification(
+        y_binary_all_ones, sampled_distributions_all_ones
+    )
+    @test mean_class_one_all_ones == mean([0.8, 0.9, 0.7])
+    @test isnan(mean_class_zero_all_ones) # Since there are no zeros in y_binary, the mean should be NaN
+
+    # Test 4: Edge case with all zeros in y_binary
+    y_binary_all_zeros = [0, 0, 0]
+    sampled_distributions_all_zeros = [0.1 0.2 0.3; 0.9 0.8 0.7]
+    mean_class_one_all_zeros, mean_class_zero_all_zeros = sharpness_classification(
+        y_binary_all_zeros, sampled_distributions_all_zeros
+    )
+    @test mean_class_zero_all_zeros == mean([0.9, 0.8, 0.7])
+    @test isnan(mean_class_one_all_zeros) # Since there are no ones in y_binary, the mean should be NaN
+end
+
+# Test for `sharpness_regression` function
+@testset "sharpness_regression tests" begin
+
+    # Test 1: Check that the function runs without errors and returns a scalar for a simple case
+    sampled_distributions = [randn(100) for _ in 1:10] # Create 10 distributions, each with 100 samples
+    sharpness = sharpness_regression(sampled_distributions)
+    @test typeof(sharpness) <: Real # Check if the output is a scalar
+
+    # Test 2: Check the function with a known input
+    sampled_distributions = [
+        [0.1, 0.2, 0.3, 0.7, 0.6], [0.2, 0.3, 0.4, 0.3, 0.5], [0.3, 0.4, 0.5, 0.9, 0.2]
+    ]
+    mean_variance = mean(map(var, sampled_distributions))
+    sharpness = sharpness_regression(sampled_distributions)
+    @test sharpness ≈ mean_variance
+
+    # Test 3: Edge case with identical distributions
+    sampled_distributions_identical = [ones(100) for _ in 1:10] # Identical distributions, zero variance
+    sharpness_identical = sharpness_regression(sampled_distributions_identical)
+    @test sharpness_identical == 0.0 # Sharpness should be zero for identical distributions
+end
+
+# Test for `empirical_frequency_regression` function
+@testset "empirical_frequency_regression tests" begin
+    # Test 1: Check that the function runs without errors and returns an array for a simple case
+    Y_cal = [0.5, 1.5, 2.5, 3.5, 4.5]
+    n_bins = 10
+    sampled_distributions = [rand(Distributions.Normal(1, 1.0), 6) for _ in 1:5]
+    counts = empirical_frequency_regression(Y_cal, sampled_distributions; n_bins=n_bins)
+    @test typeof(counts) == Array{Float64,1} # Check if the output is an array of Float64
+    @test length(counts) == n_bins + 1
+
+    # Test 2: Check the function with a known input
+    #to do
+
+    # Test 3: Invalid n_bins input
+    Y_cal = [0.5, 1.5, 2.5, 3.5, 4.5]
+    sampled_distributions = [rand(Distributions.Normal(1, 1.0), 6) for _ in 1:5]
+    @test_throws ArgumentError empirical_frequency_regression(
+        Y_cal, sampled_distributions, n_bins=0
+    )
+end
+
+# Test for `empirical_frequency_binary_classification` function
+@testset "empirical_frequency_binary_classification tests" begin
+    # Test 1: Check that the function runs without errors and returns an array for a simple case
+    y_binary = rand(0:1, 10)
+    sampled_distributions = rand(2, 10)
+    n_bins = 4
+    num_p_per_interval, emp_avg, bin_centers = empirical_frequency_binary_classification(
+        y_binary, sampled_distributions; n_bins=n_bins
+    )
+    @test length(num_p_per_interval) == n_bins
+    @test length(emp_avg) == n_bins
+    @test length(bin_centers) == n_bins
+
+    # Test 2: Check the function with a known input
+
+    #to do
+
+    # Test 3: Invalid Y_cal input
+    Y_cal = [0, 1, 0, 1.2, 4]
+    sampled_distributions = rand(2, 5)
+    @test_throws ArgumentError empirical_frequency_binary_classification(
+        Y_cal, sampled_distributions, n_bins=10
+    )
+
+    # Test 4: Invalid n_bins input
+    Y_cal = rand(0:1, 5)
+    sampled_distributions = rand(2, 5)
+    @test_throws ArgumentError empirical_frequency_binary_classification(
+        Y_cal, sampled_distributions, n_bins=0
+    )
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 5fa82a88..0459cf5a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -17,6 +17,9 @@ using Test
     @testset "Laplace" begin
         include("laplace.jl")
     end
+    @testset "Calibration Plots" begin
+        include("calibration.jl")
+    end
 
     if VERSION >= v"1.8.0"
         @testset "PyTorch Comparisons" begin