Skip to content

Commit

Permalink
Merge #227
Browse files Browse the repository at this point in the history
227: Low rank normalization r=odunbar a=odunbar

<!--- THESE LINES ARE COMMENTED -->
## Purpose 
<!--- One sentence to describe the purpose of this PR, refer to any linked issues:
#14 -- this will link to issue 14
Closes #2 -- this will automatically close issue 2 on PR merge
-->
Closes #225

## Content
<!---  specific tasks that are currently complete 
- Solution implemented
-->
- Based on the idea that the covariance defines geometrically a linear transformation from white noise to the data noise, ensemble covariance ``C`` can be viewed as ``C = R S S R^{-1}`` where ``R`` is a rotation, and ``S`` is the sqrt of the singular values of ``C``.  The transformation from white data to the actual data samples is ``C^{1/2} = RS`` and it's "whitening" inverse is ``C^{-1/2} = S^{-1}R^{-1} = S^{-1}R^T`` as ``R`` is a rotation. For rank deficient ``C``, we take the normalization to be instead ``Sinv R^T`` where ``Sinv`` is Diagonal with the first ``rank(C)`` entries equal to ``S^{-1}``, and zero otherwise. 
- Added tests to check the covariances become close to the identity in full or low-dim subspace after normalization
- Changed an input-dim consistency check to account for new dimension change

<!---
Review checklist

I have:
- followed the codebase contribution guide: https://clima.github.io/ClimateMachine.jl/latest/Contributing/
- followed the style guide: https://clima.github.io/ClimateMachine.jl/latest/DevDocs/CodeStyle/
- followed the documentation policy: https://github.com/CliMA/policies/wiki/Documentation-Policy
- checked that this PR does not duplicate an open PR.

In the Content, I have included 
- relevant unit tests, and integration tests, 
- appropriate docstrings on all functions, structs, and modules, and included relevant documentation.

-->

----
- [ ] I have read and checked the items on the review checklist.


Co-authored-by: odunbar <odunbar@caltech.edu>
  • Loading branch information
bors[bot] and odunbar authored Jul 21, 2023
2 parents 701aef3 + a05fc58 commit 89724aa
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 39 deletions.
83 changes: 51 additions & 32 deletions src/Emulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Random

export Emulator

export calculate_normalization
export build_models!
export optimize_hyperparameters!
export predict
Expand Down Expand Up @@ -63,11 +64,11 @@ struct Emulator{FT <: AbstractFloat}
training_pairs::PairedDataContainer{FT}
"Mean of input; length *input\\_dim*."
input_mean::AbstractVector{FT}
"Square root of the inverse of the input covariance matrix; size *input\\_dim* × *input\\_dim*."
"If normalizing: whether to fit models on normalized inputs (`(inputs - input_mean) * sqrt_inv_input_cov`)."
normalize_inputs::Bool
"(Linear) normalization transformation; size *input\\_dim* × *input\\_dim*."
normalization::Union{AbstractMatrix{FT}, UniformScaling{FT}, Nothing}
"Whether to fit models on normalized outputs: `outputs / standardize_outputs_factor`."
sqrt_inv_input_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}, Nothing}
"If normalizing: whether to fit models on normalized inputs (`(inputs - input_mean) * sqrt_inv_input_cov`)."
standardize_outputs::Bool
"If standardizing: Standardization factors (characteristic values of the problem)."
standardize_outputs_factors::Union{AbstractVector{FT}, Nothing}
Expand Down Expand Up @@ -106,11 +107,11 @@ function Emulator(

# [1.] Normalize the inputs?
input_mean = vec(mean(get_inputs(input_output_pairs), dims = 2)) #column vector
sqrt_inv_input_cov = nothing
normalization = nothing
if normalize_inputs
# Normalize (NB the inputs have to be of) size [input_dim × N_samples] to pass to ML tool
sqrt_inv_input_cov = sqrt(inv(Symmetric(cov(get_inputs(input_output_pairs), dims = 2))))
training_inputs = normalize(get_inputs(input_output_pairs), input_mean, sqrt_inv_input_cov)
normalization = calculate_normalization(get_inputs(input_output_pairs))
training_inputs = normalize(get_inputs(input_output_pairs), input_mean, normalization)
# new input_dim < input_dim when inputs lie in a proper linear subspace.
else
training_inputs = get_inputs(input_output_pairs)
end
Expand All @@ -130,14 +131,7 @@ function Emulator(
decorrelated_training_outputs, decomposition =
svd_transform(training_outputs, obs_noise_cov, retained_svd_frac = retained_svd_frac)

# write new pairs structure
if retained_svd_frac < 1.0
#note this changes the dimension of the outputs
training_pairs = PairedDataContainer(training_inputs, decorrelated_training_outputs)
input_dim, output_dim = size(training_pairs, 1)
else
training_pairs = PairedDataContainer(training_inputs, decorrelated_training_outputs)
end
training_pairs = PairedDataContainer(training_inputs, decorrelated_training_outputs)

# [4.] build an emulator
build_models!(machine_learning_tool, training_pairs)
Expand All @@ -158,7 +152,7 @@ function Emulator(
training_pairs,
input_mean,
normalize_inputs,
sqrt_inv_input_cov,
normalization,
standardize_outputs,
standardize_outputs_factors,
decomposition,
Expand Down Expand Up @@ -194,8 +188,21 @@ function predict(
input_dim, output_dim = size(emulator.training_pairs, 1)

N_samples = size(new_inputs, 2)
size(new_inputs, 1) == input_dim ||
throw(ArgumentError("Emulator object and input observations do not have consistent dimensions"))

# check sizing against normalization
if emulator.normalize_inputs
size(new_inputs, 1) == size(emulator.normalization, 2) || throw(
ArgumentError(
"Emulator object and input observations do not have consistent dimensions, expected $(size(emulator.normalization,2)), received $(size(new_inputs,1))",
),
)
else
size(new_inputs, 1) == input_dim || throw(
ArgumentError(
"Emulator object and input observations do not have consistent dimensions, expected $(input_dim), received $(size(new_inputs,1))",
),
)
end

# [1.] normalize
normalized_new_inputs = normalize(emulator, new_inputs)
Expand Down Expand Up @@ -272,30 +279,42 @@ end
"""
$(DocStringExtensions.TYPEDSIGNATURES)
Normalize the input data, with a normalizing function.
Calculate the normalization of inputs.
"""
function normalize(emulator::Emulator{FT}, inputs::AbstractVecOrMat{FT}) where {FT <: AbstractFloat}
if emulator.normalize_inputs
return normalize(inputs, emulator.input_mean, emulator.sqrt_inv_input_cov)
else
return inputs
function calculate_normalization(inputs::VOrM) where {VOrM <: AbstractVecOrMat}
input_mean = vec(mean(inputs, dims = 2))
input_cov = cov(inputs, dims = 2)

if rank(input_cov) == size(input_cov, 1)
normalization = sqrt(inv(input_cov))
else # if not full rank, normalize non-zero singular values
svd_in = svd(input_cov)
sqrt_inv_sv = 1 ./ sqrt.(svd_in.S[1:rank(input_cov)])
normalization = Diagonal(sqrt_inv_sv) * svd_in.Vt[1:rank(input_cov), :] #non-square
end
return normalization
end

"""
$(DocStringExtensions.TYPEDSIGNATURES)
Normalize with the empirical Gaussian distribution of points.
Normalize the input data, with a normalizing function.
"""
function normalize(
inputs::AbstractVecOrMat{FT},
input_mean::AbstractVector{FT},
sqrt_inv_input_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
) where {FT <: AbstractFloat}
training_inputs = sqrt_inv_input_cov * (inputs .- input_mean)
return training_inputs
function normalize(emulator::Emulator, inputs::VOrM) where {VOrM <: AbstractVecOrMat}
if emulator.normalize_inputs
return normalize(inputs, emulator.input_mean, emulator.normalization)
else
return inputs
end
end

function normalize(
inputs::VOrM,
input_mean::V,
normalization::M,
) where {VOrM <: AbstractVecOrMat, V <: AbstractVector, M <: AbstractMatrix}
return normalization * (inputs .- input_mean)
end
"""
$(DocStringExtensions.TYPEDSIGNATURES)
Expand Down
40 changes: 33 additions & 7 deletions test/Emulator/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ function constructor_tests(
norm_factors,
decomposition,
) where {FT <: AbstractFloat}
# [2.] test Normalization
input_mean = vec(mean(get_inputs(iopairs), dims = 2)) #column vector
sqrt_inv_input_cov = sqrt(inv(Symmetric(cov(get_inputs(iopairs), dims = 2))))
norm_inputs = Emulators.normalize(get_inputs(iopairs), input_mean, sqrt_inv_input_cov)
@test norm_inputs == sqrt_inv_input_cov * (get_inputs(iopairs) .- input_mean)

input_mean = vec(mean(get_inputs(iopairs), dims = 2)) #column vector
normalization = sqrt(inv(Symmetric(cov(get_inputs(iopairs), dims = 2))))
normalization = Emulators.calculate_normalization(get_inputs(iopairs))
norm_inputs = Emulators.normalize(get_inputs(iopairs), input_mean, normalization)
# [4.] test emulator preserves the structures
mlt = MLTester()
@test_throws ErrorException emulator = Emulator(
Expand Down Expand Up @@ -89,8 +88,9 @@ end
#build some quick data + noise
m = 50
d = 6
x = rand(3, m) #R^3
y = rand(d, m) #R^5
p = 10
x = rand(p, m) #R^3
y = rand(d, m) #R^6

# "noise"
μ = zeros(d)
Expand Down Expand Up @@ -130,12 +130,38 @@ end
@test y_new y[:, 1]
@test y_cov_new[1] Σ


# Truncation
transformed_y, trunc_decomposition = Emulators.svd_transform(y[:, 1], Σ, retained_svd_frac = 0.95)
trunc_size = size(trunc_decomposition.S)[1]
@test test_SVD.S[1:trunc_size] == trunc_decomposition.S
@test size(transformed_y)[1] == trunc_size

# [2.] test Normalization
# full rank
input_mean = vec(mean(get_inputs(iopairs), dims = 2)) #column vector
normalization = sqrt(inv(Symmetric(cov(get_inputs(iopairs), dims = 2))))
@test all(isapprox.(Emulators.calculate_normalization(get_inputs(iopairs)), normalization, atol = 1e-12))


norm_inputs = Emulators.normalize(get_inputs(iopairs), input_mean, normalization)
@test all(isapprox.(norm_inputs, normalization * (get_inputs(iopairs) .- input_mean), atol = 1e-12))
@test isapprox.(norm(cov(norm_inputs, dims = 2) - I), 0.0, atol = 1e-8)

# reduced rank
reduced_inputs = get_inputs(iopairs)[:, 1:(p - 1)]
input_mean = vec(mean(reduced_inputs, dims = 2)) #column vector
input_cov = cov(reduced_inputs, dims = 2)
r = rank(input_cov) # = p-2
svd_in = svd(input_cov)
sqrt_inv_sv = 1 ./ sqrt.(svd_in.S[1:r])
normalization = Diagonal(sqrt_inv_sv) * svd_in.Vt[1:r, :] # size r x p
@test all(isapprox.(Emulators.calculate_normalization(reduced_inputs), normalization, atol = 1e-12))

norm_inputs = Emulators.normalize(reduced_inputs, input_mean, normalization)
@test size(norm_inputs) == (r, p - 1)
@test isapprox.(norm(cov(norm_inputs, dims = 2)[1:r, 1:r] - I(r)), 0.0, atol = 1e-12)

# [3.] test Standardization
norm_factors = 10.0
norm_factors = fill(norm_factors, size(y[:, 1])) # must be size of output dim
Expand Down

0 comments on commit 89724aa

Please sign in to comment.