diff --git a/Project.toml b/Project.toml index 2a6c764..a7df60d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJTuning" uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" authors = ["Anthony D. Blaom "] -version = "0.8.7" +version = "0.8.8" [deps] ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" @@ -18,7 +18,7 @@ StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" ComputationalResources = "0.3" Distributions = "0.22,0.23,0.24, 0.25" LatinHypercubeSampling = "1.7.2" -MLJBase = "1.4" +MLJBase = "1.5" ProgressMeter = "1.7.1" RecipesBase = "0.8,0.9,1" StatisticalMeasuresBase = "0.1.1" diff --git a/src/tuned_models.jl b/src/tuned_models.jl index 70a35a9..b7add8d 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -257,6 +257,10 @@ key | value regular [`PerformanceEvaluation`](@ref) objects to the history (accessed via the `:evaluation` key); the compact form excludes some fields to conserve memory. +- `logger=default_logger()`: a logger for externally reporting model performance + evaluations, such as an `MLJFlow.Logger` instance. On startup, + `default_logger()=nothing`; use `default_logger(logger)` to set a global logger. + """ function TunedModel( args...; @@ -281,7 +285,7 @@ function TunedModel( check_measure=true, cache=true, compact_history=true, - logger=nothing + logger=MLJBase.default_logger() ) # user can specify model as argument instead of kwarg: diff --git a/test/tuned_models.jl b/test/tuned_models.jl index f99f799..cab2c8e 100644 --- a/test/tuned_models.jl +++ b/test/tuned_models.jl @@ -526,4 +526,29 @@ end @test first(evaluations) isa MLJBase.PerformanceEvaluation end +struct DummyLogger + buffer +end + +MLJBase.log_evaluation(logger::DummyLogger, performance_evaluation) = + write(logger.buffer, performance_evaluation.measurement[1]) + +@testset "default logger" begin + buffer = IOBuffer() + logger = DummyLogger(buffer) + default_logger(logger) + model1 = KNNRegressor(K=5) + model2 = KNNRegressor(K=3) + tmodel = TunedModel(models=[model1, model2], measure=l2) + mach = machine(tmodel, make_regression(10)...) + fit!(mach, verbosity=0) + seekstart(buffer) + @test all(report(mach).history) do entry + logger_measurement = read(buffer, Float64) + logger_measurement == entry.evaluation.measurement[1] + end + default_logger(nothing) + close(buffer) +end + true