diff --git a/Project.toml b/Project.toml index bf446bc..247952f 100644 --- a/Project.toml +++ b/Project.toml @@ -18,8 +18,10 @@ julia = "1.6" MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"] +test = ["Test", "MLJModels", "MLJTuning", "MLFlowClient", + "StatisticalMeasures", "MLJDecisionTreeInterface"] diff --git a/src/MLJFlow.jl b/src/MLJFlow.jl index f42d1b6..61a59fa 100644 --- a/src/MLJFlow.jl +++ b/src/MLJFlow.jl @@ -1,10 +1,9 @@ module MLJFlow -using MLJBase: Model, Machine, name +using MLJBase: Model, Machine, name using MLJModelInterface: flat_params -using MLFlowClient: MLFlow, logparam, logmetric, - createrun, MLFlowRun, updaterun, logartifact, - getorcreateexperiment +using MLFlowClient: MLFlow, logparam, logmetric, createrun, MLFlowRun, + updaterun, logartifact, getorcreateexperiment import Base: show import MLJBase: save, log_evaluation diff --git a/src/base.jl b/src/base.jl index 82413a0..977c1be 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,4 +1,4 @@ -function log_evaluation(logger::Logger, performance_evaluation) +function _log_evaluation(logger::Logger, performance_evaluation) experiment = getorcreateexperiment(logger.service, logger.experiment_name; artifact_location=logger.artifact_location) run = createrun(logger.service, experiment; @@ -19,6 +19,15 @@ function log_evaluation(logger::Logger, performance_evaluation) updaterun(logger.service, run, "FINISHED") end +function log_evaluation(logger::Logger, performance_evaluation) + result_channel = Channel{MLFlowRun}(1) + + put!(logger._logging_channel, (_log_evaluation, logger, performance_evaluation, result_channel)) + wait(result_channel) + + return take!(result_channel) +end + function save(logger::Logger, machine:: Machine) io = IOBuffer() save(io, machine) diff --git a/src/types.jl b/src/types.jl index a491479..5082190 100644 --- a/src/types.jl +++ b/src/types.jl @@ -28,13 +28,17 @@ struct Logger verbosity::Int experiment_name::String artifact_location::Union{String,Nothing} + _logging_channel::Channel{Tuple} end + function Logger(apiroot; experiment_name="MLJ experiment", artifact_location=nothing, verbosity=1) service = MLFlow(apiroot) + logging_channel = open_logging_channel() - Logger(service, verbosity, experiment_name, artifact_location) + Logger(service, verbosity, experiment_name, artifact_location, logging_channel) end + function show(io::IO, logger::MLJFlow.Logger) print(io, "MLFLowLogger(\"$(logger.service.apiroot)\",\n" * @@ -43,3 +47,39 @@ function show(io::IO, logger::MLJFlow.Logger) ") using MLFlow API version $(logger.service.apiversion)" ) end + +""" + close(logger::Logger) + +Each logger instance has a background loop that allows to execute the logging +operations from the `_logging_channel`. This function closes the channel +to stop the background loop. +""" +function close(logger::Logger) + close(logger._logging_channel) +end + +""" + open_logging_channel(logger::Logger) + +To allow safe concurrent logging operations, this function opens the +`_logging_channel` of the logger and starts a background worker. +""" +function open_logging_channel() + logging_channel = Channel{Tuple}() + + # NOTE: This background loop allows to execute the logging operations from + # the logging_channel. The execution result is sent back to the + # requesting thread through the result_channel. + # Until May 2024, mlflow does not support concurrent experiment creation, + # which does not allow to run the logging operations in multi-threading and + # multi-processing. + # + # Its usage can be seen in the `log_evaluation` function in `base.jl`. + Threads.@spawn for (logging_function, logger, performance_evaluation, result_channel) in logging_channel + result = logging_function(logger, performance_evaluation) + put!(result_channel, result) + end + + return logging_channel +end diff --git a/test/base.jl b/test/base.jl index 02d3f00..6140474 100644 --- a/test/base.jl +++ b/test/base.jl @@ -12,8 +12,8 @@ measures=[LogLoss(), Accuracy()], verbosity=1, logger=logger) @testset "log_evaluation" begin - runs = searchruns(logger.service, - getexperiment(logger.service, logger.experiment_name)) + experiment = getexperiment(logger.service, logger.experiment_name) + runs = searchruns(logger.service, experiment) @test typeof(runs[1]) == MLFlowRun end diff --git a/test/multiprocessing.jl b/test/multiprocessing.jl new file mode 100644 index 0000000..4076a19 --- /dev/null +++ b/test/multiprocessing.jl @@ -0,0 +1,38 @@ +@testset verbose = true "multiprocessing" begin + logger = MLJFlow.Logger(ENV["MLFLOW_TRACKING_URI"]; + experiment_name="MLJFlow multiprocessing tests", + artifact_location="/tmp/mlj-test") + + X, y = make_moons(100) + DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree + + model = DecisionTreeClassifier() + r = range(model, :max_depth, lower=1, upper=6) + + function test_tuned_model(acceleration_method) + tuned_model = TunedModel( + model=model, + range=r, + logger=logger, + acceleration=acceleration_method, + n=100, + ) + tuned_model_mach = machine(tuned_model, X, y) + fit!(tuned_model_mach) + + experiment = getorcreateexperiment(logger.service, logger.experiment_name) + runs = searchruns(logger.service, experiment) + + @assert length(runs) == 100 + + deleteexperiment(logger.service, experiment) + end + + @testset "log_evaluation_with_cpu_threads" begin + test_tuned_model(CPUThreads()) + end + + @testset "log_evaluation_with_cpu_processes" begin + test_tuned_model(CPUProcesses()) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index fd090ce..cbe8a7a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using MLJFlow using MLJBase using MLJModels +using MLJTuning using MLFlowClient using MLJModelInterface using StatisticalMeasures @@ -21,4 +22,4 @@ end include("base.jl") include("types.jl") include("service.jl") - +include("multiprocessing.jl")