Skip to content

Commit

Permalink
Merge pull request #219 from JuliaAI/constructor
Browse files Browse the repository at this point in the history
Overload `constructor` trait for `TunedModel` types
  • Loading branch information
ablaom authored Jun 3, 2024
2 parents 2294da4 + b4fbf67 commit c13e844
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.8.6"
version = "0.8.7"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand All @@ -18,7 +18,7 @@ StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
ComputationalResources = "0.3"
Distributions = "0.22,0.23,0.24, 0.25"
LatinHypercubeSampling = "1.7.2"
MLJBase = "1.3"
MLJBase = "1.4"
ProgressMeter = "1.7.1"
RecipesBase = "0.8,0.9,1"
StatisticalMeasuresBase = "0.1.1"
Expand Down
11 changes: 7 additions & 4 deletions src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -905,17 +905,19 @@ function MLJBase.feature_importances(::EitherTunedModel, fitresult, report)
return MLJBase.feature_importances(fitresult)
end





## METADATA

MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
MLJBase.supports_weights(M)
MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
MLJBase.supports_class_weights(M)
MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) =
"MLJTuning.ProbabilisticTunedModel"
MLJBase.load_path(::Type{<:DeterministicTunedModel}) =
"MLJTuning.DeterministicTunedModel"
MLJBase.load_path(::Type{<:EitherTunedModel}) =
"MLJTuning.TunedModel"
MLJBase.package_name(::Type{<:EitherTunedModel}) = "MLJTuning"
MLJBase.package_uuid(::Type{<:EitherTunedModel}) =
"03970b2e-30c4-11ea-3135-d1576263f10f"
Expand All @@ -928,3 +930,4 @@ MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
MLJBase.input_scitype(M)
MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
MLJBase.target_scitype(M)
MLJBase.constructor(::Type{<:EitherTunedModel}) = TunedModel
2 changes: 2 additions & 0 deletions test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ end
TunedModel(first(r), last(r), range=r, measure=l2),
)
tm = @test_logs TunedModel(model=first(r), range=r, measure=l2)
@test MLJBase.constructor(tm) == TunedModel
@test MLJBase.load_path(tm) == "MLJTuning.TunedModel"
@test tm.tuning isa RandomSearch
@test input_scitype(tm) == Table(Continuous)

Expand Down

0 comments on commit c13e844

Please sign in to comment.