From f2c7d9738db9de7086d7180324e06bfad361bc07 Mon Sep 17 00:00:00 2001 From: getzze Date: Tue, 25 Jul 2023 14:27:02 +0100 Subject: [PATCH] Enforce type consistency (#43) * enforce type consistency --- src/robustlinearmodel.jl | 28 ++++++++++++++++------------ src/tools.jl | 11 +++++++++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/robustlinearmodel.jl b/src/robustlinearmodel.jl index 86ee199..a0644db 100644 --- a/src/robustlinearmodel.jl +++ b/src/robustlinearmodel.jl @@ -98,8 +98,8 @@ function StatsAPI.fit( end # Make sure X and y have the same float eltype - T = promote_type(float(eltype(X)), float(eltype(y))) - return fit(M, convert.(T, X), convert.(T, y), args...; kwargs...) + pX, py = promote_to_same_float(X, y) + return fit(M, pX, py, args...; kwargs...) end ## Convert from formula-data to modelmatrix-response calling form @@ -119,7 +119,10 @@ function StatsAPI.fit( # Extract arrays from data using formula f, y, X, extra = modelframe(f, data, contrasts, dropmissing, M; wts=wts) # Call the `fit` method with arrays - return fit(M, X, y, args...; wts=extra.wts, contrasts=contrasts, __formula=f, kwargs...) + pX, py = promote_to_same_float(X, y) + return fit( + M, pX, py, args...; wts=extra.wts, contrasts=contrasts, __formula=f, kwargs... + ) end @@ -1021,11 +1024,11 @@ function pirls!( devold = deviance(m) absdev = abs(devold) dev = devold - Δdev = 0 + Δdev = zero(T) verbose && println("initial deviance: $(@sprintf("%.4g", devold))") for i in 1:maxiter - f = 1.0 # line search factor + f = one(T) # line search factor # local dev absdev = abs(devold) @@ -1124,12 +1127,12 @@ function pirls_Sestimate!( sigold = scale( setη!(m; updatescale=true, verbose=verbose, sigma0=sigma0, fallback=maxσ) ) - installbeta!(p, 1) + installbeta!(p, one(T)) r.σ = sigold verbose && println("initial iteration scale: $(@sprintf("%.4g", sigold))") for i in 1:maxiter - f = 1.0 # line search factor + f = one(T) # line search factor local sig # Compute the change to β, update μ and compute deviance @@ -1242,11 +1245,11 @@ function pirls_τestimate!( # Compute initial τ-scale tauold = tauscale(setη!(m; updatescale=true); verbose=verbose) - installbeta!(p, 1) + installbeta!(p, one(T)) verbose && println("initial iteration τ-scale: $(@sprintf("%.4g", tauold))") for i in 1:maxiter - f = 1.0 # line search factor + f = one(T) # line search factor local tau # Compute the change to β, update μ and compute deviance @@ -1366,6 +1369,7 @@ function resampling_best_estimate( ## Hubert2015 - The DetS and DetMM estimators for multivariate location and scatter ## (https://www.sciencedirect.com/science/article/abs/pii/S0167947314002175) M = length(coef(m)) + T = eltype(coef(m)) if isnothing(Nsamples) Nsamples = resampling_minN(M, 0.05, propoutliers) @@ -1377,8 +1381,8 @@ function resampling_best_estimate( verbose && println("Start $(Nsamples) subsamples...") - σis = zeros(Nsamples) - βis = zeros(M, Nsamples) + σis = zeros(T, Nsamples) + βis = zeros(T, M, Nsamples) for i in 1:Nsamples # TODO: to parallelize, make a deepcopy of m inds = sample(rng, axes(response(m), 1), Npoints; replace=false, ordered=false) @@ -1393,7 +1397,7 @@ function resampling_best_estimate( # Initialize σ as mad(residuals) setinitσ!(m) - σi = 0 + σi = zero(T) for k in 1:Nsteps_β setη!( m; diff --git a/src/tools.jl b/src/tools.jl index cc99859..5fe133a 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -4,6 +4,17 @@ ## Missing values ################################################ +function promote_to_same_float(X::AbstractMatrix, y::AbstractVector) + T = promote_type(float(eltype(X)), float(eltype(y))) + if !(T <: AbstractFloat) + msg = "promoting X and y arrays to float types" + throw(TypeError(:fit, msg, Type{<:AbstractFloat}, T)) + end + MT = AbstractMatrix{T} + VT = AbstractVector{T} + return convert.(T, X)::MT, convert.(T, y)::VT +end + _missing_omit(x::AbstractArray{T}) where {T} = copyto!(similar(x, nonmissingtype(T)), x) function StatsModels.missing_omit(X::AbstractMatrix, y::AbstractVector)