Implementing channel control for experiment creation #41
Changes from all commits
e11e294
c7ee07d
0e5a16c
6332639
420b235
f4d3e30
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,13 +28,17 @@ struct Logger | |
verbosity::Int | ||
experiment_name::String | ||
artifact_location::Union{String,Nothing} | ||
_logging_channel::Channel{Tuple} | ||
Nit picking point: Can't
I considered
I think in Julia fields are generally considered private, especially if they are not referenced in any docstring. To make them public you would ordinarily provide an accessor function. But there's no harm in this; I just found it a bit strange. |
||
end | ||
|
||
function Logger(apiroot; experiment_name="MLJ experiment", | ||
artifact_location=nothing, verbosity=1) | ||
service = MLFlow(apiroot) | ||
logging_channel = open_logging_channel() | ||
|
||
Logger(service, verbosity, experiment_name, artifact_location) | ||
Logger(service, verbosity, experiment_name, artifact_location, logging_channel) | ||
end | ||
|
||
function show(io::IO, logger::MLJFlow.Logger) | ||
print(io, | ||
"MLFLowLogger(\"$(logger.service.apiroot)\",\n" * | ||
|
@@ -43,3 +47,39 @@ function show(io::IO, logger::MLJFlow.Logger) | |
") using MLFlow API version $(logger.service.apiversion)" | ||
) | ||
end | ||
|
||
""" | ||
close(logger::Logger) | ||
Each logger instance runs a background loop that executes the logging | ||
operations taken from its `_logging_channel`. This function closes the channel, | ||
which stops the background loop. | ||
""" | ||
function close(logger::Logger) | ||
close(logger._logging_channel) | ||
end | ||
|
||
""" | ||
open_logging_channel() | ||
To allow safe concurrent logging operations, this function creates the channel | ||
stored in a logger's `_logging_channel` field and starts a background worker. | ||
""" | ||
function open_logging_channel() | ||
logging_channel = Channel{Tuple}() | ||
|
||
# NOTE: This background loop executes the logging operations taken from | ||
# the logging_channel. Each execution result is sent back to the | ||
# requesting thread through the result_channel. | ||
# As of May 2024, mlflow does not support concurrent experiment creation, | ||
# so the logging operations cannot be issued directly from multiple threads | ||
# or processes. | ||
# | ||
# Its usage can be seen in the `log_evaluation` function in `base.jl`. | ||
Threads.@spawn for (logging_function, logger, performance_evaluation, result_channel) in logging_channel | ||
result = logging_function(logger, performance_evaluation) | ||
put!(result_channel, result) | ||
end | ||
|
||
return logging_channel | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
@testset verbose = true "multiprocessing" begin | ||
logger = MLJFlow.Logger(ENV["MLFLOW_TRACKING_URI"]; | ||
experiment_name="MLJFlow multiprocessing tests", | ||
artifact_location="/tmp/mlj-test") | ||
|
||
X, y = make_moons(100) | ||
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree | ||
|
||
model = DecisionTreeClassifier() | ||
r = range(model, :max_depth, lower=1, upper=6) | ||
|
||
function test_tuned_model(acceleration_method) | ||
tuned_model = TunedModel( | ||
model=model, | ||
range=r, | ||
logger=logger, | ||
acceleration=acceleration_method, | ||
n=100, | ||
) | ||
tuned_model_mach = machine(tuned_model, X, y) | ||
fit!(tuned_model_mach) | ||
|
||
experiment = getorcreateexperiment(logger.service, logger.experiment_name) | ||
runs = searchruns(logger.service, experiment) | ||
|
||
@assert length(runs) == 100 | ||
|
||
deleteexperiment(logger.service, experiment) | ||
end | ||
|
||
@testset "log_evaluation_with_cpu_threads" begin | ||
test_tuned_model(CPUThreads()) | ||
end | ||
|
||
@testset "log_evaluation_with_cpu_processes" begin | ||
test_tuned_model(CPUProcesses()) | ||
end | ||
end |
Perhaps this function should close `result_channel` before returning?

The `result_channel` is only temporary. When we perform `take!(result_channel)`, the value is retrieved and returned. The garbage collector will take care of the channel when the function ends, so it's safe to leave it open.

Also, we are not looping over the `result_channel` in any way, so we don't have to worry about that.

(I can remove the `wait(result_channel)` at line 26 because `take!(result_channel)` will actually wait until the channel contains a value.)