From ec26d1cb6180152d988ba2643dfd14c9babe2c0f Mon Sep 17 00:00:00 2001 From: Rafael Schouten Date: Tue, 8 Jun 2021 15:35:24 +1000 Subject: [PATCH] add macros --- src/optics.jl | 172 ++++++++++++++++++++----------------------- src/sugar.jl | 34 ++++++++- test/test_queries.jl | 26 +++++-- 3 files changed, 129 insertions(+), 103 deletions(-) diff --git a/src/optics.jl b/src/optics.jl index aa443a49..6f3b3e4c 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -135,8 +135,8 @@ _constructor(::MaybeConstruct, ::Type{T}) where T = constructorof(T) struct List end _constructor(::List, ::Type) = tuple -struct Skip end -_constructor(::Skip, ::Type) = _splat_all +struct Splat end +_constructor(::Splat, ::Type) = _splat_all _splat_all(args...) = _splat_all(args) @generated function _splat_all(args::A) where A<:Tuple @@ -242,7 +242,7 @@ end abstract type ObjectMap end OpticStyle(::Type{<:ObjectMap}) = ModifyBased() -modify(f, o, optic::ObjectMap) = mapobject(f, o, optic, Construct, nothing) +modify(f, o, optic::ObjectMap) = mapobject(f, o, optic, Construct) """ Properties() @@ -283,17 +283,17 @@ julia> Accessors.mapobject(x -> x+1, obj) ``` $EXPERIMENTAL """ -function mapobject(f, obj::O, ::Properties, handler, itr::Nothing) where O +function mapobject(f, obj::O, ::Properties, handler) where O # TODO move this helper elsewhere? pnames = propertynames(obj) if isempty(pnames) - return _maybeskip(handler, obj) + return skip(handler) ? () : obj else - new_props = map(pnames) do p + ctr = _constructor(handler, O) + args = map(pnames) do p f(getproperty(obj, p)) end - ctr = _constructor(handler, O) - return ctr(new_props...) + return ctr(args...) end end function mapobject(f, obj::O, ::Properties, handler, itr::Int) where O @@ -334,68 +334,50 @@ $EXPERIMENTAL """ struct Fields <: ObjectMap end -@generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Nothing) where {O,H,I} +@generated function mapobject(f, obj::O, ::Fields, handler::H) where {O,H,I} # TODO: This is how Flatten.jl works, but it's not really # correct use of ConstructionBase as it assumers properties=fields fnames = fieldnames(O) ctr = _constructor(H(), O) if isempty(fnames) - :(return _maybeskip(handler, obj)) + skip(H()) ? :(()) : :(obj) else - prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames) - prop_exp = Expr(:tuple, prop_args...) - new_prop_exp = Expr(:tuple, map(pa -> :(f($pa)), prop_args)...) - quote - props = $prop_exp - new_props = $new_prop_exp - return $ctr(new_props...) + args = map(fnames) do fn + :(f(getfield(obj, $(QuoteNode(fn))))) end + args_exp = Expr(:tuple, args...) + return :($ctr($args_exp...)) end end @generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Int) where {O,H} - # TODO: This is how Flatten.jl works, but it's not really - # correct use of ConstructionBase as it assumers properties=fields fnames = fieldnames(O) - ctr = _constructor(H(), O) if isempty(fnames) - :(return (obj, itr) => Unchanged()) + :(obj => Unchanged(), itr) else - prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames) - prop_exp = Expr(:tuple, prop_args...) - ### Unrolled iterating function appliation (it will compile away) #### - # Each function call also updates the iterator value in local scoope with - # the return value from the function. But it only actually inserts the - # value into the parent tuple. - val_exps = map(prop_args) do pa - :(((val, itr), change) = f($pa, itr); val => change) - end - new_prop_exp = Expr(:tuple, val_exps...) - quote - props = $prop_exp - new_props = $new_prop_exp - new_props, change = _splitchanged(new_props) - # Don't construct when we don't absolutely have to. - # `constructorof` may not be defined for an object. - if change isa Changed - return ($ctr(new_props...), itr) => change - else - return (obj, itr) => change - end + ### Unrolled iterating function appliation #### + # Each function call updates the iterator value in + # local scoope with its return value + args = map(fnames) do fn + :((val, itr) = f(getfield(obj, $(QuoteNode(fn))), itr); val) end + args_exp = Expr(:tuple, args...) + return :(_maybeconstruct(obj, $args_exp, handler), itr) end end -_splitchanged(props) = map(first, props), _findchanged(map(last, props)) - -_findchanged(::Tuple{Changed,Vararg}) = Changed() -_findchanged(cs::Tuple) = _findchanged(Base.tail(cs)) -_findchanged(::Tuple{}) = Unchanged() - -_maybeitr(x, ::Nothing) = x -_maybeitr(x, itr) = x, itr +# Don't construct when we don't absolutely have to. +# `constructorof` may not be defined for an object. +@generated function _maybeconstruct(obj::O, props::P, handler::H) where {O,P,H} + ctr = _constructor(H(), O) + if Changed in map(last ∘ fieldtypes, fieldtypes(P)) + :($ctr(map(first, props)...) => Changed()) + else + :(obj => Unchanged()) + end +end -_maybeskip(::Skip, v) = () -_maybeskip(x, v) = v +skip(::Splat) = true +skip(x) = false """ Recursive(descent_condition, optic) @@ -433,45 +415,8 @@ function _modify(f, obj, r::Recursive, ::ModifyBased) end end -################################################################################ -##### Lenses -################################################################################ -struct PropertyLens{fieldname} end - -function (l::PropertyLens{field})(obj) where {field} - getproperty(obj, field) -end - -@inline function set(obj, l::PropertyLens{field}, val) where {field} - patch = (;field => val) - setproperties(obj, patch) -end - -struct IndexLens{I <: Tuple} - indices::I -end - -Base.@propagate_inbounds function (lens::IndexLens)(obj) - getindex(obj, lens.indices...) -end -Base.@propagate_inbounds function set(obj, lens::IndexLens, val) - setindex(obj, val, lens.indices...) -end - -struct DynamicIndexLens{F} - f::F -end - -Base.@propagate_inbounds function (lens::DynamicIndexLens)(obj) - return obj[lens.f(obj)...] -end - -Base.@propagate_inbounds function set(obj, lens::DynamicIndexLens, val) - return setindex(obj, val, lens.f(obj)...) -end - """ - Query(select, descend) + Query(select, descend, optic) Query an object recursively, choosing fields when `select` returns `true`, and descending when `descend`. @@ -501,8 +446,10 @@ end Query(select, descend = x -> true) = Query(select, descend, Fields()) Query(; select=Any, descend=x -> true, optic=Fields()) = Query(select, descend, optic) +OpticStyle(::Type{<:Query}) = SetBased() + function (q::Query)(obj) - mapobject(obj, _inner(q.optic), Skip(), nothing) do o + mapobject(obj, _inner(q.optic), Splat()) do o if q.select_condition(o) (_getouter(o, q.optic),) elseif q.descent_condition(o) @@ -518,11 +465,11 @@ set(obj, q::Query, vals) = _set(obj, q::Query, (vals, 1))[1][1] function _set(obj, q::Query, (vals, itr)) mapobject(obj, _inner(q.optic), MaybeConstruct(), itr) do o, itr if q.select_condition(o) - (_setouter(o, q.optic, vals[itr]), itr + 1) => Changed() + _setouter(o, q.optic, vals[itr]) => Changed(), itr + 1 elseif q.descent_condition(o) - _set(o, q, (vals, itr)) + _set(o, q, (vals, itr)) # Will be marked as Changed()/Unchanged() else - (o, itr) => Unchanged() + o => Unchanged(), itr end end end @@ -535,3 +482,40 @@ _getouter(o, optic::ComposedOptic) = optic.outer(o) _getouter(o, optic) = o _setouter(o, optic::ComposedOptic, v) = set(o, optic.outer, v) _setouter(o, optic, v) = v + +################################################################################ +##### Lenses +################################################################################ +struct PropertyLens{fieldname} end + +function (l::PropertyLens{field})(obj) where {field} + getproperty(obj, field) +end + +@inline function set(obj, l::PropertyLens{field}, val) where {field} + patch = (;field => val) + setproperties(obj, patch) +end + +struct IndexLens{I <: Tuple} + indices::I +end + +Base.@propagate_inbounds function (lens::IndexLens)(obj) + getindex(obj, lens.indices...) +end +Base.@propagate_inbounds function set(obj, lens::IndexLens, val) + setindex(obj, val, lens.indices...) +end + +struct DynamicIndexLens{F} + f::F +end + +Base.@propagate_inbounds function (lens::DynamicIndexLens)(obj) + return obj[lens.f(obj)...] +end + +Base.@propagate_inbounds function set(obj, lens::DynamicIndexLens, val) + return setindex(obj, val, lens.f(obj)...) +end diff --git a/src/sugar.jl b/src/sugar.jl index 4dd90f8c..7ee9da06 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -1,4 +1,4 @@ -export @set, @optic, @reset, @modify +export @set, @optic, @reset, @modify, @getall, @setall using MacroTools """ @@ -84,7 +84,6 @@ end This function can be used to create a customized variant of [`@modify`](@ref). See also [`opticmacro`](@ref), [`setmacro`](@ref). """ - function modifymacro(optictransform, f, obj_optic) f = esc(f) obj, optic = parse_obj_optic(obj_optic) @@ -94,6 +93,37 @@ function modifymacro(optictransform, f, obj_optic) end) end +""" + @getall f(obj, arg...) + +@getall obj isa Number +""" +macro getall(ex) + ex.head == :call || error("@getall must be a function call") + obj = ex.args[2] + var = gensym() + ex.args[2] = var + esc(:(Query($var -> $ex)($obj))) +end + +""" + @setall f(obj, arg...) = values + + +""" +macro setall(ex) + ex.head == :(=) || error("@setall must contain an = assignment") + func = ex.args[1] + vals = ex.args[2] + func.head == :call || error("@setall must contain a function call") + obj = func.args[2] + var = gensym() + func.args[2] = var + esc(:(set($obj, Query($var -> $func), $vals))) +end + +dump(:(a = b)) + foldtree(op, init, x) = op(init, x) foldtree(op, init, ex::Expr) = op(foldl((acc, x) -> foldtree(op, acc, x), ex.args; init=init), ex) diff --git a/test/test_queries.jl b/test/test_queries.jl index 96bc2236..1eb9cf29 100644 --- a/test/test_queries.jl +++ b/test/test_queries.jl @@ -15,8 +15,9 @@ slowlens = Query(; optic = (Accessors.@optic _.a) ∘ Accessors.Properties() ) -@code_typed lens(obj) -@code_typed slowlens(obj) +lens(obj) +@code_warntype lens(obj) +@code_warntype slowlens(obj) @code_native lens(obj) @code_native slowlens(obj) @@ -28,26 +29,37 @@ println("get") missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2))) @test Query(ismissing)(missings_obj) === (missing, missing, missing) +@btime Query(ismissing)($missings_obj) === (missing, missing, missing) println("set") # Need a wrapper so we don't have to pass in the starting iterator -@btime Accessors.set($obj, $lens, $vals) +set(obj, lens, vals) +@btime set($obj, $lens, $vals) @btime Accessors._set($obj, $lens, ($vals, 1))[1] # @btime Accessors.set($obj, $slowlens, $vals) -Accessors.set(obj, lens, vals) @test Accessors.set(obj, lens, vals) == Accessors.set(obj, lens, vals) == (7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((a=2.0,), [1])) -@code_warntype Accessors.set(obj, lens, vals) +@code_warntype set(obj, lens, vals) +@code_native set(obj, lens, vals) +@code_native Accessors._set(obj, lens, (vals, 1))[1] + +# using Cthulhu +# using ProfileView +# @profview for i in 1:1000000 Accessors.set(obj, lens, vals) end +# @descend Accessors.set(obj, lens, vals) println("unstable set") unstable_lens = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x isa NamedTuple) -@btime Accessors.set($obj, $unstable_lens, $vals) +@btime set($obj, $unstable_lens, $vals) # slow_unstable_lens = Accessors.Query(; select=x -> x isa Number && x > 4, optic=Properties()) # @btime Accessors.set($obj, $slow_unstable_lens, $vals)) # Somehow modify compiles away almost completely @btime modify(x -> 10x, $obj, $lens) -@test modify(x -> 10x, obj, lens) == (7, (a=170.0, b=2.0f0), ("3", 4, 5.0), ((a=60.0,), [1])) +# Macros +@test (@getall missings_obj isa Number) == (1, 2) +expected = (a=missing, b=5, c=(d=missing, e=(f=missing, g=6))) +@test (@setall missings_obj isa Number = (5, 6)) === expected