diff --git a/src/optics.jl b/src/optics.jl index ef02e4c2..aa443a49 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -129,24 +129,14 @@ end struct Changed end struct Unchanged end -struct ConstructIfChanged{C} - constructor::C -end - -# TODO what do we call these things? -struct Construct end -_constructor(::Construct, ::Type{T}) where T = constructorof(T) - struct MaybeConstruct end -function _constructor(::MaybeConstruct, ::Type{T}) where T - ConstructIfChanged(constructorof(T)) -end +_constructor(::MaybeConstruct, ::Type{T}) where T = constructorof(T) struct List end _constructor(::List, ::Type) = tuple -struct Splat end -_constructor(::Splat, ::Type) = _splat_all +struct Skip end +_constructor(::Skip, ::Type) = _splat_all _splat_all(args...) = _splat_all(args) @generated function _splat_all(args::A) where A<:Tuple @@ -297,7 +287,7 @@ function mapobject(f, obj::O, ::Properties, handler, itr::Nothing) where O # TODO move this helper elsewhere? pnames = propertynames(obj) if isempty(pnames) - return obj + return _maybeskip(handler, obj) else new_props = map(pnames) do p f(getproperty(obj, p)) @@ -344,56 +334,69 @@ $EXPERIMENTAL """ struct Fields <: ObjectMap end -@generated function mapobject(f, obj::O, ::Fields, handler::H=Construct(), itr::I=nothing) where {O,H,I} +@generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Nothing) where {O,H,I} # TODO: This is how Flatten.jl works, but it's not really # correct use of ConstructionBase as it assumers properties=fields fnames = fieldnames(O) ctr = _constructor(H(), O) if isempty(fnames) - :(return _maybeitr(obj, itr)) + :(return _maybeskip(handler, obj)) else prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames) prop_exp = Expr(:tuple, prop_args...) - if I === Nothing - new_prop_exp = Expr(:tuple, map(pa -> :(f($pa)), prop_args)...) - else - ### Unrolled iterating function appliation (it will compile away) #### - # Each function call also updates the iterator value in local scoope with - # the return value from the function. But it only actually inserts the - # value into the parent tuple. - val_exps = map(prop_args) do pa - :((val, itr) = f($pa, itr); val) - end - new_prop_exp = Expr(:tuple, val_exps...) + new_prop_exp = Expr(:tuple, map(pa -> :(f($pa)), prop_args)...) + quote + props = $prop_exp + new_props = $new_prop_exp + return $ctr(new_props...) end - ret = if H == MaybeConstruct - quote - # TODO: last type instability. - # replace this with val => Changed(), val => Unchanged() - # return values. - # - # Don't construct when we don't absolutely have to. - # `constructorof` may not be defined for an object. - if props === new_props - return _maybeitr(obj, itr) - else - return _maybeitr($ctr(new_props...), itr) - end - end - else - ret = :(return _maybeitr($ctr(new_props...), itr)) + end +end +@generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Int) where {O,H} + # TODO: This is how Flatten.jl works, but it's not really + # correct use of ConstructionBase as it assumers properties=fields + fnames = fieldnames(O) + ctr = _constructor(H(), O) + if isempty(fnames) + :(return (obj, itr) => Unchanged()) + else + prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames) + prop_exp = Expr(:tuple, prop_args...) + ### Unrolled iterating function appliation (it will compile away) #### + # Each function call also updates the iterator value in local scoope with + # the return value from the function. But it only actually inserts the + # value into the parent tuple. + val_exps = map(prop_args) do pa + :(((val, itr), change) = f($pa, itr); val => change) end + new_prop_exp = Expr(:tuple, val_exps...) quote props = $prop_exp new_props = $new_prop_exp - $ret + new_props, change = _splitchanged(new_props) + # Don't construct when we don't absolutely have to. + # `constructorof` may not be defined for an object. + if change isa Changed + return ($ctr(new_props...), itr) => change + else + return (obj, itr) => change + end end end end +_splitchanged(props) = map(first, props), _findchanged(map(last, props)) + +_findchanged(::Tuple{Changed,Vararg}) = Changed() +_findchanged(cs::Tuple) = _findchanged(Base.tail(cs)) +_findchanged(::Tuple{}) = Unchanged() + _maybeitr(x, ::Nothing) = x _maybeitr(x, itr) = x, itr +_maybeskip(::Skip, v) = () +_maybeskip(x, v) = v + """ Recursive(descent_condition, optic) @@ -499,7 +502,7 @@ Query(select, descend = x -> true) = Query(select, descend, Fields()) Query(; select=Any, descend=x -> true, optic=Fields()) = Query(select, descend, optic) function (q::Query)(obj) - mapobject(obj, _inner(q.optic), Splat(), nothing) do o + mapobject(obj, _inner(q.optic), Skip(), nothing) do o if q.select_condition(o) (_getouter(o, q.optic),) elseif q.descent_condition(o) @@ -510,16 +513,16 @@ function (q::Query)(obj) end end -set(obj, q::Query, vals) = _set(obj, q::Query, (vals, 1))[1] +set(obj, q::Query, vals) = _set(obj, q::Query, (vals, 1))[1][1] function _set(obj, q::Query, (vals, itr)) - mapobject(obj, _inner(q.optic), Construct(), itr) do o, itr + mapobject(obj, _inner(q.optic), MaybeConstruct(), itr) do o, itr if q.select_condition(o) - _setouter(o, q.optic, vals[itr]), itr + 1 + (_setouter(o, q.optic, vals[itr]), itr + 1) => Changed() elseif q.descent_condition(o) _set(o, q, (vals, itr)) else - o, itr + (o, itr) => Unchanged() end end end diff --git a/test/test_queries.jl b/test/test_queries.jl index 592d8cfd..96bc2236 100644 --- a/test/test_queries.jl +++ b/test/test_queries.jl @@ -26,11 +26,14 @@ println("get") @btime $slowlens($obj) @test lens(obj) == slowlens(obj) == (17.0, 6.0) +missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2))) +@test Query(ismissing)(missings_obj) === (missing, missing, missing) + println("set") # Need a wrapper so we don't have to pass in the starting iterator @btime Accessors.set($obj, $lens, $vals) @btime Accessors._set($obj, $lens, ($vals, 1))[1] -@btime Accessors.set($obj, $slowlens, $vals) +# @btime Accessors.set($obj, $slowlens, $vals) Accessors.set(obj, lens, vals) @test Accessors.set(obj, lens, vals) == Accessors.set(obj, lens, vals) == @@ -47,3 +50,4 @@ unstable_lens = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> # Somehow modify compiles away almost completely @btime modify(x -> 10x, $obj, $lens) @test modify(x -> 10x, obj, lens) == (7, (a=170.0, b=2.0f0), ("3", 4, 5.0), ((a=60.0,), [1])) +