Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a Handler hierarchy and DictHandler #50

Merged
merged 2 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpoints"
uuid = "b4a3413d-e481-5afc-88ff-bdfbd6a50dce"
authors = "Invenia Technical Computing Corporation"
version = "0.3.20"
version = "0.3.21"

[deps]
AWSS3 = "1c724243-ef5b-51ab-93f4-b0a88ac62a95"
Expand Down
22 changes: 11 additions & 11 deletions src/Checkpoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ __init__() = Memento.register(LOGGER)

include("handler.jl")

const CHECKPOINTS = Dict{String, Union{Nothing, String, Handler}}()
const CHECKPOINTS = Dict{String, Union{Nothing, String, AbstractHandler}}()
@contextvar CONTEXT_TAGS::Tuple{Vararg{Pair{Symbol, Any}}} = Tuple{}()

include("session.jl")
Expand Down Expand Up @@ -75,7 +75,7 @@ available() = collect(keys(CHECKPOINTS))
Returns a vector of all enabled ([`config`](@ref)ured) and not [`deprecate`](@ref)d checkpoints.
Use [`deprecated_checkpoints`](@ref) to retrieve a mapping of old / deprecated checkpoints.
"""
enabled_checkpoints() = filter(k -> CHECKPOINTS[k] isa Handler, available())
enabled_checkpoints() = filter(k -> CHECKPOINTS[k] isa AbstractHandler, available())

"""
deprecated_checkpoints() -> Dict{String, String}
Expand Down Expand Up @@ -130,31 +130,31 @@ function checkpoint(prefix::Union{Module, String}, name::String, args...; tags..
end

"""
config(handler::Handler, labels::Vector{String})
config(handler::Handler, prefix::String)
config(handler::AbstractHandler, labels::Vector{String})
config(handler::AbstractHandler, prefix::String)
config(labels::Vector{String}, args...; kwargs...)
config(prefix::String, args...; kwargs...)

Configures the specified checkpoints with a `Handler`.
If the first argument is not a `Handler` then all `args` and `kwargs` are passed to a
`Handler` constructor for you.
Configures the specified checkpoints with an `AbstractHandler`.
If the first argument is not an `AbstractHandler` then all `args` and `kwargs` are
passed to a `JLSOHandler` constructor for you.
"""
function config(handler::Handler, names::Vector{String})
function config(handler::AbstractHandler, names::Vector{String})
for n in names
_config(handler, n)
end
end

function config(handler::Handler, prefix::Union{Module, String})
function config(handler::AbstractHandler, prefix::Union{Module, String})
config(handler, filter(l -> startswith(l, prefix), available()))
end

function config(names::Vector{String}, args...; kwargs...)
config(Handler(args...; kwargs...), names)
config(JLSOHandler(args...; kwargs...), names)
end

function config(prefix::Union{Module, String}, args...; kwargs...)
config(Handler(args...; kwargs...), prefix)
config(JLSOHandler(args...; kwargs...), prefix)
end

# To avoid collisions with `prefix` method above, which should probably use
Expand Down
2 changes: 2 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ function checkpoint_deprecation(tags...)
:checkpoint
)
end

Base.@deprecate_binding Handler JLSOHandler
115 changes: 76 additions & 39 deletions src/handler.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,82 @@
struct Handler{P<:AbstractPath}
path::P
settings # Could be Vector or Pairs on 0.6 or 1.0 respectively
end
abstract type AbstractHandler end

"""
Handler(path::Union{String, AbstractPath}; kwargs...)
Handler(bucket::String, prefix::String; kwargs...)
getkey(handler, name, separator="/") -> String

Handles iteratively saving JLSO file to the specified path location.
FilePath are used to abstract away differences between paths on S3 or locally.
Combine the `CONTEXT_TAGS` and `name` into a unique checkpoint key as a string.
If the checkpoint name includes `.`, usually representing nested modules, these are
also replaced with the provided separator.
"""
Handler(path::AbstractPath; kwargs...) = Handler(path, kwargs)
Handler(path::String; kwargs...) = Handler(Path(path), kwargs)
Handler(bucket::String, prefix::String; kwargs...) = Handler(S3Path("s3://$bucket/$prefix"), kwargs)
function getkey(::AbstractHandler, name::String, separator="/")::String
    # The handler argument is only used for dispatch; the key is built from the
    # current context tags and the checkpoint name.
    # Each context tag contributes a leading "key=value" component.
    tags = ["$k=$v" for (k, v) in CONTEXT_TAGS[]]
    # A dotted name such as "Foo.bar.x" yields one component per segment.
    segments = split(name, '.')
    return Base.join([tags; segments], separator)
end

"""
path(handler, name)
path(args...) = Path(getkey(args...))

Determines the path to save to based on the handlers path prefix, name, and context.
Tags are used to dynamically prefix the named file with the handler's path.
Names with a '.' separators will be used to form subdirectories
(e.g., "Foo.bar.x" will be saved to "\$prefix/Foo/bar/x.jlso").
"""
function path(handler::Handler{P}, name::String) where P
prefix = ["$key=$val" for (key,val) in CONTEXT_TAGS[]]
stage!(handler::AbstractHandler, objects, data::Dict{Symbol})

# Split up the name by '.' and add the jlso extension
parts = split(name, '.')
parts[end] = string(parts[end], ".jlso")
Update the objects with the new data.
By default all handlers assume objects implements the associative interface.
"""
function stage!(handler::AbstractHandler, objects, data::Dict{Symbol})
for (k, v) in data
objects[k] = v
end

return join(handler.path, prefix..., parts...)
return objects
end

"""
stage!(handler::Handler, jlso::JLSOFIle, data::Dict{Symbol})
commit!(handler, prefix, objects)

Update the JLSOFile with the new data.
Serialize and write objects to a given path/prefix/key as defined by the handler.
"""
function stage!(handler::Handler, jlso::JLSO.JLSOFile, data::Dict{Symbol})
for (k, v) in data
jlso[k] = v
commit!

#=
Define our no-op conditions just to be safe
=#
function checkpoint(handler::Nothing, name::String, data::Dict{Symbol}; tags...)
    # With no handler configured, `data` is never stored: the call is only logged
    # at debug level and evaluates to `nothing`.
    checkpoint_deprecation(tags...)
    return with_checkpoint_tags(tags...) do
        debug(LOGGER, "Checkpoint $name triggered, but no handler has been set.")
        return nothing
    end
end


return jlso
struct JLSOHandler{P<:AbstractPath} <: AbstractHandler
    # Base path that checkpoint file names are joined onto.
    path::P
    # Keyword settings handed to the constructors.
    # Could be Vector or Pairs on 0.6 or 1.0 respectively
    settings::Any
end

"""
commit!(handler, path, jlso)
JLSOHandler(path::Union{String, AbstractPath}; kwargs...)
JLSOHandler(bucket::String, prefix::String; kwargs...)

Handles iteratively saving JLSO files to the specified path location.
FilePaths are used to abstract away differences between paths on S3 or locally.
"""
function JLSOHandler(path::AbstractPath; kwargs...)
    # Keyword arguments are stored untouched as the handler's settings.
    return JLSOHandler(path, kwargs)
end

function JLSOHandler(path::String; kwargs...)
    # A plain string is first turned into a path object.
    return JLSOHandler(Path(path), kwargs)
end

function JLSOHandler(bucket::String, prefix::String; kwargs...)
    # A bucket/prefix pair is addressed through an S3 path.
    return JLSOHandler(S3Path("s3://$bucket/$prefix"), kwargs)
end

"""
path(handler, name)

Write the JLSOFile to the path as bytes.
Determines the path to save to based on the handler's path prefix, name, and context.
Tags are used to dynamically prefix the named file with the handler's path.
Names with '.' separators will be used to form subdirectories
(e.g., "Foo.bar.x" will be saved to "\$prefix/Foo/bar/x.jlso").
"""
function commit!(handler::Handler{P}, path::P, jlso::JLSO.JLSOFile) where P <: AbstractPath
function path(handler::JLSOHandler, name::String)
    # The checkpoint key (context tags plus the dotted name) gets the JLSO
    # extension and is then joined onto the handler's base path.
    filename = string(getkey(handler, name), ".jlso")
    return join(handler.path, filename)
end

function commit!(handler::JLSOHandler{P}, path::P, jlso::JLSO.JLSOFile) where P <: AbstractPath
# NOTE: This is only necessary because FilePathsBase.FileBuffer needs to support
# write(::FileBuffer, ::UInt8)
# https://github.com/rofinn/FilePathsBase.jl/issues/45
Expand All @@ -61,7 +87,7 @@ function commit!(handler::Handler{P}, path::P, jlso::JLSO.JLSOFile) where P <: A
write(path, bytes)
end

function checkpoint(handler::Handler, name::String, data::Dict{Symbol}; tags...)
function checkpoint(handler::JLSOHandler, name::String, data::Dict{Symbol}; tags...)
checkpoint_deprecation(tags...)
with_checkpoint_tags(tags...) do
debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).")
Expand All @@ -72,13 +98,24 @@ function checkpoint(handler::Handler, name::String, data::Dict{Symbol}; tags...)
end
end

#=
Define our no-op conditions just to be safe
=#
function checkpoint(handler::Nothing, name::String, data::Dict{Symbol}; tags...)
"""
DictHandler(objects)

Saves checkpointed objects into a dictionary where the keys are strings generated from
the checkpoint tags and name.
"""
struct DictHandler <: AbstractHandler
    # Maps each checkpoint key (see `getkey`) to the data saved under it.
    objects::Dict{String, Dict}

    # `objects` is optional: `DictHandler()` starts from a new empty dictionary, while
    # the documented `DictHandler(objects)` form stores checkpoints in the given
    # dictionary itself (its type matches the field exactly, so it is not copied).
    DictHandler(objects::Dict{String, Dict}=Dict{String, Dict}()) = new(objects)
end

function commit!(handler::DictHandler, k::AbstractString, data)
    # Store `data` under key `k`; like `setindex!`, hand back the dictionary.
    store = handler.objects
    store[k] = data
    return store
end

function checkpoint(handler::DictHandler, name::String, data::Dict{Symbol}; tags...)
# TODO: Remove duplicate wrapper code
checkpoint_deprecation(tags...)
with_checkpoint_tags(tags...) do
debug(LOGGER, "Checkpoint $name triggered, but no handler has been set.")
nothing
debug(LOGGER, "Checkpoint $name triggered, with context: $(join(CONTEXT_TAGS[], ", ")).")
handler.objects[getkey(handler, name)] = data
morris25 marked this conversation as resolved.
Show resolved Hide resolved
end
end
38 changes: 25 additions & 13 deletions src/session.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct Session{H<:Union{Nothing, Handler}}
struct Session{H<:Union{Nothing, AbstractHandler}}
name::String
handler::H
objects::DefaultDict
Expand All @@ -8,11 +8,7 @@ function Session(name::String)
# Create our objects dictionary; its key type and default value are chosen
# by `session_objects` for the configured handler
handler = CHECKPOINTS[name]

objects = DefaultDict{AbstractPath, JLSO.JLSOFile}() do
JLSO.JLSOFile(Dict{Symbol, Vector{UInt8}}(); handler.settings...)
end

objects = session_objects(handler)
Session{typeof(handler)}(name, handler, objects)
end

Expand All @@ -34,29 +30,45 @@ function Session(f::Function, prefix::Union{Module, String}, names::Vector{Strin
Session(f, map(n -> "$prefix.$n", names))
end

function session_objects(handler)
    # Fallback for any handler without a more specific method: string keys, with
    # missing entries defaulting to an empty `Dict{Symbol, Any}`.
    empty_entry() = Dict{Symbol, Any}()
    return DefaultDict{AbstractString, Dict}(empty_entry)
end

function session_objects(handler::JLSOHandler)
    # JLSO handlers stage into JLSOFiles keyed by path; a missing entry defaults
    # to an empty JLSOFile built with the handler's settings.
    settings = handler.settings
    return DefaultDict{AbstractPath, JLSO.JLSOFile}(
        () -> JLSO.JLSOFile(Dict{Symbol, Vector{UInt8}}(); settings...)
    )
end

"""
commit!(session)

Write all staged JLSOFiles to the respective paths.
Write all staged objects to the respective keys.
"""
function commit!(session::Session)
# No-ops skip when handler is nothing
session.handler === nothing && return nothing

for (p, jlso) in session.objects
commit!(session.handler, p, jlso)
for (k, v) in session.objects
commit!(session.handler, k, v)
end
end

function checkpoint(session::Session, data::Dict{Symbol}; tags...)
checkpoint_deprecation(tags...)
with_checkpoint_tags(tags...) do
handler = session.handler
name = session.name
K = keytype(session.objects)

# No-ops skip when handler is nothing
session.handler === nothing && return nothing
handler === nothing && return nothing

p = path(session.handler, session.name)
jlso = session.objects[p]
session.objects[p] = stage!(session.handler, jlso, data)
# Our handler may not always be storing data in filepaths
k = K <: AbstractPath ? path(handler, name) : getkey(handler, name)
session.objects[k] = stage!(handler, session.objects[k], data)
end
end

Expand Down
36 changes: 36 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Test
using AWS: AWSConfig
using AWSS3: S3Path, s3_put, s3_list_buckets, s3_create_bucket
using Tables: Tables
using Checkpoints: JLSOHandler, DictHandler

Distributed.addprocs(5)
@everywhere using Checkpoints
Expand Down Expand Up @@ -256,5 +257,40 @@ Distributed.addprocs(5)
@test data[:data] == b
end
end

# We're largely reusing the same code for different handlers, but make sure
# that saving to a dict also works.
@testset "DictHandler" begin
# `a` maps ten random 4-character symbols to random length-10 vectors;
# `b` is a single random length-10 vector.
a = Dict(zip(
map(x -> Symbol(randstring(4)), 1:10),
map(x -> rand(10), 1:10)
))
b = rand(10)
handler = DictHandler()
# Alias of the handler's dictionary, so later checkpoints are visible through it.
objects = handler.objects
# Configure every available checkpoint whose name starts with "TestPkg".
Checkpoints.config(handler, "TestPkg")

# Nothing has been checkpointed yet.
@test isempty(handler.objects)
# NOTE(review): `x` and `y` are not defined in this testset — presumably they
# come from the enclosing test scope; confirm.
TestPkg.foo(x, y)
# Keys are the checkpoint name with '.' replaced by '/' (see `getkey`).
@test haskey(objects, "TestPkg/foo")
@test issetequal(keys(objects["TestPkg/foo"]), [:x, :y])
@test objects["TestPkg/foo"][:x] == x
@test objects["TestPkg/foo"][:y] == y

# Presumably `bar` checkpoints under a `date` tag, which shows up as a leading
# "date=..." key component — confirm against TestPkg.
TestPkg.bar(b)
@test haskey(objects, "date=2017-01-01/TestPkg/bar")
@test objects["date=2017-01-01/TestPkg/bar"][:data] == b

# The whole dictionary `a` is expected back as the stored checkpoint data.
TestPkg.baz(a)
@test haskey(objects, "TestPkg/baz")
@test objects["TestPkg/baz"] == a

# `qux` is expected to produce two separate entries, one per checkpoint name.
TestPkg.qux(a, b)
@test haskey(objects, "TestPkg/qux_a")
@test objects["TestPkg/qux_a"] == a

@test haskey(objects, "TestPkg/qux_b")
@test objects["TestPkg/qux_b"][:data] == b
end
end
end