Skip to content

Commit

Permalink
Merge pull request #41 from JuliaAI/implementing-channel-control-for-…
Browse files Browse the repository at this point in the history
…experiment-creation

Implementing channel control for experiment creation
  • Loading branch information
pebeto authored May 21, 2024
2 parents c1076e9 + f4d3e30 commit 8a10ac3
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 10 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ julia = "1.6"
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"]
test = ["Test", "MLJModels", "MLJTuning", "MLFlowClient",
"StatisticalMeasures", "MLJDecisionTreeInterface"]
7 changes: 3 additions & 4 deletions src/MLJFlow.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
module MLJFlow

using MLJBase: Model, Machine, name
using MLJBase: Model, Machine, name
using MLJModelInterface: flat_params
using MLFlowClient: MLFlow, logparam, logmetric,
createrun, MLFlowRun, updaterun, logartifact,
getorcreateexperiment
using MLFlowClient: MLFlow, logparam, logmetric, createrun, MLFlowRun,
updaterun, logartifact, getorcreateexperiment

import Base: show
import MLJBase: save, log_evaluation
Expand Down
11 changes: 10 additions & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function log_evaluation(logger::Logger, performance_evaluation)
function _log_evaluation(logger::Logger, performance_evaluation)
experiment = getorcreateexperiment(logger.service, logger.experiment_name;
artifact_location=logger.artifact_location)
run = createrun(logger.service, experiment;
Expand All @@ -19,6 +19,15 @@ function log_evaluation(logger::Logger, performance_evaluation)
updaterun(logger.service, run, "FINISHED")
end

"""
    log_evaluation(logger::Logger, performance_evaluation)

Queue `performance_evaluation` for logging on `logger`'s `_logging_channel` and block
until the outcome is available.

The request is the tuple `(_log_evaluation, logger, performance_evaluation, reply)`,
where `reply` is a fresh one-slot `Channel{MLFlowRun}`. The value received on `reply`
is returned.
"""
function log_evaluation(logger::Logger, performance_evaluation)
    reply = Channel{MLFlowRun}(1)
    request = (_log_evaluation, logger, performance_evaluation, reply)
    put!(logger._logging_channel, request)

    # `take!` blocks until a value has been put on `reply`, so it does the waiting too.
    return take!(reply)
end

function save(logger::Logger, machine:: Machine)
io = IOBuffer()
save(io, machine)
Expand Down
42 changes: 41 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ struct Logger
verbosity::Int
experiment_name::String
artifact_location::Union{String,Nothing}
_logging_channel::Channel{Tuple}
end

"""
    Logger(apiroot; experiment_name="MLJ experiment", artifact_location=nothing,
           verbosity=1)

Construct a `Logger` for the MLflow service reachable at `apiroot`.

# Keywords
- `experiment_name`: stored in the `experiment_name` field.
- `artifact_location`: stored in the `artifact_location` field (`nothing` by default).
- `verbosity`: stored in the `verbosity` field.

A new logging channel is opened with `open_logging_channel()` and stored in the
`_logging_channel` field.
"""
function Logger(apiroot; experiment_name="MLJ experiment",
    artifact_location=nothing, verbosity=1)
    service = MLFlow(apiroot)
    logging_channel = open_logging_channel()

    # The struct has five fields, so the channel must always be supplied; a
    # four-argument call has no matching constructor.
    return Logger(service, verbosity, experiment_name, artifact_location,
        logging_channel)
end

function show(io::IO, logger::MLJFlow.Logger)
print(io,
"MLFLowLogger(\"$(logger.service.apiroot)\",\n" *
Expand All @@ -43,3 +47,39 @@ function show(io::IO, logger::MLJFlow.Logger)
") using MLFlow API version $(logger.service.apiversion)"
)
end

"""
    close(logger::Logger)

Close the `_logging_channel` of `logger`.

Each logger has a background loop that executes the logging operations taken from
its `_logging_channel`. Closing the channel ends that loop; further `put!` calls on
the closed channel throw.
"""
function Base.close(logger::Logger)
    # Defined as a method of `Base.close` so that the call below dispatches to the
    # `Channel` method instead of to a module-local function named `close`.
    close(logger._logging_channel)
    return nothing
end

"""
    open_logging_channel()

Create the unbuffered `Channel{Tuple}` through which logging requests are serialised,
start the background worker that consumes it, and return the channel.

Each request is a tuple
`(logging_function, logger, performance_evaluation, result_channel)`. The worker calls
`logging_function(logger, performance_evaluation)` and `put!`s the result on
`result_channel`. If that fails, `result_channel` is closed with the exception, so the
requester blocked on it receives the error, and the worker goes on to serve later
requests.

The worker stops once the returned channel is closed.
"""
function open_logging_channel()
    logging_channel = Channel{Tuple}()

    # NOTE: This background loop executes the logging operations taken from
    # the logging_channel. The execution result is sent back to the
    # requesting task through the result_channel.
    # As of May 2024, mlflow does not support concurrent experiment creation,
    # so the logging operations cannot be run directly from multiple threads or
    # processes.
    #
    # Its usage can be seen in the `log_evaluation` function in `base.jl`.
    Threads.@spawn for request in logging_channel
        logging_function, logger, performance_evaluation, result_channel = request
        try
            result = logging_function(logger, performance_evaluation)
            put!(result_channel, result)
        catch err
            # Without this, a failing request would kill the worker and leave the
            # requester (and every later one) blocked forever. Closing the result
            # channel with the exception rethrows it in the waiting task.
            failure = err isa Exception ? err : ErrorException(string(err))
            close(result_channel, failure)
        end
    end

    return logging_channel
end
4 changes: 2 additions & 2 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
measures=[LogLoss(), Accuracy()], verbosity=1, logger=logger)

@testset "log_evaluation" begin
    # Look the experiment up once, then list its runs (a single `searchruns` query).
    experiment = getexperiment(logger.service, logger.experiment_name)
    runs = searchruns(logger.service, experiment)
    @test typeof(runs[1]) == MLFlowRun
end

Expand Down
38 changes: 38 additions & 0 deletions test/multiprocessing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Checks that every evaluation of a `TunedModel` fit is logged as one run when the
# tuning is accelerated with threads or with processes.
@testset verbose = true "multiprocessing" begin
    logger = MLJFlow.Logger(ENV["MLFLOW_TRACKING_URI"];
        experiment_name="MLJFlow multiprocessing tests",
        artifact_location="/tmp/mlj-test")

    X, y = make_moons(100)
    DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree

    model = DecisionTreeClassifier()
    r = range(model, :max_depth, lower=1, upper=6)

    # Number of models requested from the tuner; one logged run is expected for each.
    n_models = 100

    # Fit a `TunedModel` with the given acceleration, check the number of logged runs,
    # then delete the experiment.
    function test_tuned_model(acceleration_method)
        tuned_model = TunedModel(
            model=model,
            range=r,
            logger=logger,
            acceleration=acceleration_method,
            n=n_models,
        )
        tuned_model_mach = machine(tuned_model, X, y)
        fit!(tuned_model_mach)

        experiment = getorcreateexperiment(logger.service, logger.experiment_name)
        runs = searchruns(logger.service, experiment)

        # `@test` records the outcome in the enclosing testset and, unlike `@assert`,
        # does not abort before the cleanup below.
        @test length(runs) == n_models

        deleteexperiment(logger.service, experiment)
    end

    @testset "log_evaluation_with_cpu_threads" begin
        test_tuned_model(CPUThreads())
    end

    @testset "log_evaluation_with_cpu_processes" begin
        test_tuned_model(CPUProcesses())
    end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using MLJFlow

using MLJBase
using MLJModels
using MLJTuning
using MLFlowClient
using MLJModelInterface
using StatisticalMeasures
Expand All @@ -21,4 +22,4 @@ end
include("base.jl")
include("types.jl")
include("service.jl")

include("multiprocessing.jl")

0 comments on commit 8a10ac3

Please sign in to comment.