Skip to content

Commit

Permalink
Refactor extension folder
Browse files Browse the repository at this point in the history
  • Loading branch information
juliohm committed Nov 26, 2024
1 parent c42e38b commit 7d5c665
Show file tree
Hide file tree
Showing 15 changed files with 33 additions and 74 deletions.
19 changes: 4 additions & 15 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,31 @@ authors = ["Kai Xu <xukai921110@gmail.com>", "Júlio Hoffimann <julio.hoffimann@
version = "1.2.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[extensions]
DensityRatioEstimationChainRulesCoreExt = "ChainRulesCore"
DensityRatioEstimationConvexExt = ["Convex", "ECOS"]
DensityRatioEstimationGPUArraysExt = "GPUArrays"
DensityRatioEstimationJuMPExt = ["JuMP", "Ipopt"]
DensityRatioEstimationOptimExt = "Optim"

[compat]
ChainRulesCore = "1"
ChainRulesCore = "1.25.0"
Convex = "0.15, 0.16"
ECOS = "1"
GPUArrays = "8, 9, 10, 11"
GPUArraysCore = "0.2.0"
Ipopt = "1"
JuMP = "1"
LinearAlgebra = "1.9"
Expand All @@ -40,12 +38,3 @@ Random = "1.9"
Statistics = "1.9"
StatsBase = "0.33, 0.34"
julia = "1.9"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
12 changes: 0 additions & 12 deletions ext/DensityRatioEstimationChainRulesCoreExt.jl

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

module DensityRatioEstimationConvexExt

using DensityRatioEstimation
using DensityRatioEstimation: KLIEP, ConvexLib
using Convex
using ECOS

include("../src/kliep/convex.jl")
import DensityRatioEstimation

include("kliep.jl")

end #module
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

# This file is part of the module DensityRatioEstimationConvexExt.

function DensityRatioEstimation._kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{ConvexLib})
# retrieve parameters
σ, b = dre.σ, size(K_de, 2)
Expand Down
18 changes: 0 additions & 18 deletions ext/DensityRatioEstimationGPUArraysExt.jl

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

module DensityRatioEstimationJuMPExt

using DensityRatioEstimation
using DensityRatioEstimation: LSIF, JuMPLib, AbstractKMM, uKMM, KMM
using DensityRatioEstimation.Parameters
using JuMP
using Ipopt
using LinearAlgebra
using Statistics

include("../src/kmm/jump.jl")
include("../src/lsif/jump.jl")
import DensityRatioEstimation

include("kmm.jl")
include("lsif.jl")

end #module
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

module DensityRatioEstimationOptimExt

using DensityRatioEstimation
using DensityRatioEstimation: KLIEP, LSIF, OptimLib
using LinearAlgebra
using Optim

using LinearAlgebra
import DensityRatioEstimation

include("../src/kliep/optim.jl")
include("../src/lsif/optim.jl")
include("kliep.jl")
include("lsif.jl")

end #module
File renamed without changes.
File renamed without changes.
7 changes: 3 additions & 4 deletions src/DensityRatioEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ using LinearAlgebra
using Parameters
using Random

# implement fit for estimators
using GPUArraysCore: AbstractGPUMatrix
using ChainRulesCore: @non_differentiable

import StatsBase: fit

# API for density ratio estimation
Expand All @@ -29,9 +31,6 @@ include("lsif.jl")
# available estimator fitters
include("lcv.jl")

# pure Julia implementations
include("kmm/julia.jl")

export
# optim libs
OptimizationLibrary,
Expand Down
6 changes: 6 additions & 0 deletions src/kmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,9 @@ default_optlib(dre::Type{<:KMM}) = JuMPLib
available_optlib(dre::Type{<:KMM}) = [JuMPLib]

_kmm_ratios(K, κ, dre::AbstractKMM, optlib::Type{<:OptimizationLibrary}) = _throw_opt_error(dre, optlib)

# pure Julia implementation
function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuliaLib})
# density ratio via solver
K \ vec(κ)
end
8 changes: 0 additions & 8 deletions src/kmm/julia.jl

This file was deleted.

15 changes: 9 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ Return the Euclidean distance between two indexable objects.
"""
euclidsq(x, y) = sum((x[i] - y[i])^2 for i in eachindex(x))

# Support matrix data in a GPU and AD compatible way

"""
euclidsq(X::T, Y::T) where {T<:AbstractMatrix}
Expand Down Expand Up @@ -55,13 +53,18 @@ gaussian_gramian(esq, σ::AbstractFloat) = exp.(-esq ./ 2σ^2)
Generate a squared matrix whose diagonal is `a` that is
compatible to perform addition on `mat`. It behaves
differently based on whether `mat` is on a CPU or GPU.
It is compatible with
- CuArrays.jl (see lib/cuarrays.jl)
- Zygote.jl (see lib/zygote.jl)
"""
safe_diagm(mat, a) = a * I

# avoid `mat + a * I` on GPU which involves scalar operations and is slow
function safe_diagm(mat::AbstractGPUMatrix, a)
diag = similar(mat, size(m, 1))
fill!(diag, a)
Diagonal(diag)
end

@non_differentiable safe_diagm(::Any, ::Any)

###################################################
## Functions and objects for throwing errors ##
###################################################
Expand Down

0 comments on commit 7d5c665

Please sign in to comment.