Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing channel control for experiment creation #41

Merged
merged 6 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ julia = "1.6"
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"]
test = ["Test", "MLJModels", "MLJTuning", "MLFlowClient",
"StatisticalMeasures", "MLJDecisionTreeInterface"]
7 changes: 3 additions & 4 deletions src/MLJFlow.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
module MLJFlow

using MLJBase: Model, Machine, name
using MLJBase: Model, Machine, name
using MLJModelInterface: flat_params
using MLFlowClient: MLFlow, logparam, logmetric,
createrun, MLFlowRun, updaterun, logartifact,
getorcreateexperiment
using MLFlowClient: MLFlow, logparam, logmetric, createrun, MLFlowRun,
updaterun, logartifact, getorcreateexperiment

import Base: show
import MLJBase: save, log_evaluation
Expand Down
11 changes: 10 additions & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function log_evaluation(logger::Logger, performance_evaluation)
function _log_evaluation(logger::Logger, performance_evaluation)
experiment = getorcreateexperiment(logger.service, logger.experiment_name;
artifact_location=logger.artifact_location)
run = createrun(logger.service, experiment;
Expand All @@ -19,6 +19,15 @@ function log_evaluation(logger::Logger, performance_evaluation)
updaterun(logger.service, run, "FINISHED")
end

function log_evaluation(logger::Logger, performance_evaluation)
result_channel = Channel{MLFlowRun}(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this function should close result_channel before returning?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result_channel is just temporal. When we perform take!(result_channel), the value is retrieved and returned. The garbage collector will take care off the channel when function ends, so it's safe to maintain it open.
Also, we are not looping the result_channel in any way, so we don't have to worry about that.

(I can remove the wait(result_channel) at line 26 because take!(result_channel) will actually wait until the Channel contains a value)


put!(logger._logging_channel, (_log_evaluation, logger, performance_evaluation, result_channel))
wait(result_channel)

return take!(result_channel)
end

function save(logger::Logger, machine:: Machine)
io = IOBuffer()
save(io, machine)
Expand Down
42 changes: 41 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ struct Logger
verbosity::Int
experiment_name::String
artifact_location::Union{String,Nothing}
_logging_channel::Channel{Tuple}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit picking point: Can't _logging_channel be logging_channel ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered _logging_channel instead of logging_channel because I want it to be like a "private" field. The user is not intended to touch that field. However, I discovered that I made a mistake on open_logging_channel() because I'm returning a value instead of setting it up in the logger.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in Julia fields are generally considered private, especially if they are not referenced in any doc-string. To make them public you would ordinarily provide an accessor function. But there's no harm in this - I just found it a bit strange.

end

function Logger(apiroot; experiment_name="MLJ experiment",
artifact_location=nothing, verbosity=1)
service = MLFlow(apiroot)
logging_channel = open_logging_channel()

Logger(service, verbosity, experiment_name, artifact_location)
Logger(service, verbosity, experiment_name, artifact_location, logging_channel)
end

function show(io::IO, logger::MLJFlow.Logger)
print(io,
"MLFLowLogger(\"$(logger.service.apiroot)\",\n" *
Expand All @@ -43,3 +47,39 @@ function show(io::IO, logger::MLJFlow.Logger)
") using MLFlow API version $(logger.service.apiversion)"
)
end

"""
close(logger::Logger)
Each logger instance has a background loop that allows to execute the logging
operations from the `_logging_channel`. This function closes the channel
to stop the background loop.
"""
function close(logger::Logger)
close(logger._logging_channel)
end

"""
open_logging_channel(logger::Logger)
To allow safe concurrent logging operations, this function opens the
`_logging_channel` of the logger and starts a background worker.
"""
function open_logging_channel()
logging_channel = Channel{Tuple}()

# NOTE: This background loop allows to execute the logging operations from
# the logging_channel. The execution result is sent back to the
# requesting thread through the result_channel.
# Until May 2024, mlflow does not support concurrent experiment creation,
# which does not allow to run the logging operations in multi-threading and
# multi-processing.
#
# Its usage can be seen in the `log_evaluation` function in `base.jl`.
Threads.@spawn for (logging_function, logger, performance_evaluation, result_channel) in logging_channel
result = logging_function(logger, performance_evaluation)
put!(result_channel, result)
end

return logging_channel
end
4 changes: 2 additions & 2 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
measures=[LogLoss(), Accuracy()], verbosity=1, logger=logger)

@testset "log_evaluation" begin
runs = searchruns(logger.service,
getexperiment(logger.service, logger.experiment_name))
experiment = getexperiment(logger.service, logger.experiment_name)
runs = searchruns(logger.service, experiment)
@test typeof(runs[1]) == MLFlowRun
end

Expand Down
38 changes: 38 additions & 0 deletions test/multiprocessing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
@testset verbose = true "multiprocessing" begin
logger = MLJFlow.Logger(ENV["MLFLOW_TRACKING_URI"];
experiment_name="MLJFlow multiprocessing tests",
artifact_location="/tmp/mlj-test")

X, y = make_moons(100)
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree

model = DecisionTreeClassifier()
r = range(model, :max_depth, lower=1, upper=6)

function test_tuned_model(acceleration_method)
tuned_model = TunedModel(
model=model,
range=r,
logger=logger,
acceleration=acceleration_method,
n=100,
)
tuned_model_mach = machine(tuned_model, X, y)
fit!(tuned_model_mach)

experiment = getorcreateexperiment(logger.service, logger.experiment_name)
runs = searchruns(logger.service, experiment)

@assert length(runs) == 100

deleteexperiment(logger.service, experiment)
end

@testset "log_evaluation_with_cpu_threads" begin
test_tuned_model(CPUThreads())
end

@testset "log_evaluation_with_cpu_processes" begin
test_tuned_model(CPUProcesses())
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using MLJFlow

using MLJBase
using MLJModels
using MLJTuning
using MLFlowClient
using MLJModelInterface
using StatisticalMeasures
Expand All @@ -21,4 +22,4 @@ end
include("base.jl")
include("types.jl")
include("service.jl")

include("multiprocessing.jl")
Loading