Skip to content

Commit

Permalink
add macros
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Jun 8, 2021
1 parent e6e851b commit ec26d1c
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 103 deletions.
172 changes: 78 additions & 94 deletions src/optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ _constructor(::MaybeConstruct, ::Type{T}) where T = constructorof(T)
struct List end
_constructor(::List, ::Type) = tuple

struct Skip end
_constructor(::Skip, ::Type) = _splat_all
struct Splat end
_constructor(::Splat, ::Type) = _splat_all

_splat_all(args...) = _splat_all(args)
@generated function _splat_all(args::A) where A<:Tuple
Expand Down Expand Up @@ -242,7 +242,7 @@ end
abstract type ObjectMap end

OpticStyle(::Type{<:ObjectMap}) = ModifyBased()
modify(f, o, optic::ObjectMap) = mapobject(f, o, optic, Construct, nothing)
modify(f, o, optic::ObjectMap) = mapobject(f, o, optic, Construct)

"""
Properties()
Expand Down Expand Up @@ -283,17 +283,17 @@ julia> Accessors.mapobject(x -> x+1, obj)
```
$EXPERIMENTAL
"""
function mapobject(f, obj::O, ::Properties, handler, itr::Nothing) where O
function mapobject(f, obj::O, ::Properties, handler) where O
# TODO move this helper elsewhere?
pnames = propertynames(obj)
if isempty(pnames)
return _maybeskip(handler, obj)
return skip(handler) ? () : obj
else
new_props = map(pnames) do p
ctr = _constructor(handler, O)
args = map(pnames) do p
f(getproperty(obj, p))
end
ctr = _constructor(handler, O)
return ctr(new_props...)
return ctr(args...)
end
end
function mapobject(f, obj::O, ::Properties, handler, itr::Int) where O
Expand Down Expand Up @@ -334,68 +334,50 @@ $EXPERIMENTAL
"""
struct Fields <: ObjectMap end

@generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Nothing) where {O,H,I}
@generated function mapobject(f, obj::O, ::Fields, handler::H) where {O,H,I}
# TODO: This is how Flatten.jl works, but it's not really
# correct use of ConstructionBase as it assumers properties=fields
fnames = fieldnames(O)
ctr = _constructor(H(), O)
if isempty(fnames)
:(return _maybeskip(handler, obj))
skip(H()) ? :(()) : :(obj)
else
prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames)
prop_exp = Expr(:tuple, prop_args...)
new_prop_exp = Expr(:tuple, map(pa -> :(f($pa)), prop_args)...)
quote
props = $prop_exp
new_props = $new_prop_exp
return $ctr(new_props...)
args = map(fnames) do fn
:(f(getfield(obj, $(QuoteNode(fn)))))
end
args_exp = Expr(:tuple, args...)
return :($ctr($args_exp...))
end
end
@generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Int) where {O,H}
# TODO: This is how Flatten.jl works, but it's not really
# correct use of ConstructionBase as it assumers properties=fields
fnames = fieldnames(O)
ctr = _constructor(H(), O)
if isempty(fnames)
:(return (obj, itr) => Unchanged())
:(obj => Unchanged(), itr)
else
prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames)
prop_exp = Expr(:tuple, prop_args...)
### Unrolled iterating function appliation (it will compile away) ####
# Each function call also updates the iterator value in local scoope with
# the return value from the function. But it only actually inserts the
# value into the parent tuple.
val_exps = map(prop_args) do pa
:(((val, itr), change) = f($pa, itr); val => change)
end
new_prop_exp = Expr(:tuple, val_exps...)
quote
props = $prop_exp
new_props = $new_prop_exp
new_props, change = _splitchanged(new_props)
# Don't construct when we don't absolutely have to.
# `constructorof` may not be defined for an object.
if change isa Changed
return ($ctr(new_props...), itr) => change
else
return (obj, itr) => change
end
### Unrolled iterating function appliation ####
# Each function call updates the iterator value in
# local scoope with its return value
args = map(fnames) do fn
:((val, itr) = f(getfield(obj, $(QuoteNode(fn))), itr); val)
end
args_exp = Expr(:tuple, args...)
return :(_maybeconstruct(obj, $args_exp, handler), itr)
end
end

_splitchanged(props) = map(first, props), _findchanged(map(last, props))

_findchanged(::Tuple{Changed,Vararg}) = Changed()
_findchanged(cs::Tuple) = _findchanged(Base.tail(cs))
_findchanged(::Tuple{}) = Unchanged()

_maybeitr(x, ::Nothing) = x
_maybeitr(x, itr) = x, itr
# Don't construct when we don't absolutely have to.
# `constructorof` may not be defined for an object.
@generated function _maybeconstruct(obj::O, props::P, handler::H) where {O,P,H}
ctr = _constructor(H(), O)
if Changed in map(last fieldtypes, fieldtypes(P))
:($ctr(map(first, props)...) => Changed())
else
:(obj => Unchanged())
end
end

_maybeskip(::Skip, v) = ()
_maybeskip(x, v) = v
skip(::Splat) = true
skip(x) = false

"""
Recursive(descent_condition, optic)
Expand Down Expand Up @@ -433,45 +415,8 @@ function _modify(f, obj, r::Recursive, ::ModifyBased)
end
end

################################################################################
##### Lenses
################################################################################
struct PropertyLens{fieldname} end

function (l::PropertyLens{field})(obj) where {field}
getproperty(obj, field)
end

@inline function set(obj, l::PropertyLens{field}, val) where {field}
patch = (;field => val)
setproperties(obj, patch)
end

struct IndexLens{I <: Tuple}
indices::I
end

Base.@propagate_inbounds function (lens::IndexLens)(obj)
getindex(obj, lens.indices...)
end
Base.@propagate_inbounds function set(obj, lens::IndexLens, val)
setindex(obj, val, lens.indices...)
end

struct DynamicIndexLens{F}
f::F
end

Base.@propagate_inbounds function (lens::DynamicIndexLens)(obj)
return obj[lens.f(obj)...]
end

Base.@propagate_inbounds function set(obj, lens::DynamicIndexLens, val)
return setindex(obj, val, lens.f(obj)...)
end

"""
Query(select, descend)
Query(select, descend, optic)
Query an object recursively, choosing fields when `select`
returns `true`, and descending when `descend`.
Expand Down Expand Up @@ -501,8 +446,10 @@ end
Query(select, descend = x -> true) = Query(select, descend, Fields())
Query(; select=Any, descend=x -> true, optic=Fields()) = Query(select, descend, optic)

OpticStyle(::Type{<:Query}) = SetBased()

function (q::Query)(obj)
mapobject(obj, _inner(q.optic), Skip(), nothing) do o
mapobject(obj, _inner(q.optic), Splat()) do o
if q.select_condition(o)
(_getouter(o, q.optic),)
elseif q.descent_condition(o)
Expand All @@ -518,11 +465,11 @@ set(obj, q::Query, vals) = _set(obj, q::Query, (vals, 1))[1][1]
function _set(obj, q::Query, (vals, itr))
mapobject(obj, _inner(q.optic), MaybeConstruct(), itr) do o, itr
if q.select_condition(o)
(_setouter(o, q.optic, vals[itr]), itr + 1) => Changed()
_setouter(o, q.optic, vals[itr]) => Changed(), itr + 1
elseif q.descent_condition(o)
_set(o, q, (vals, itr))
_set(o, q, (vals, itr)) # Will be marked as Changed()/Unchanged()
else
(o, itr) => Unchanged()
o => Unchanged(), itr
end
end
end
Expand All @@ -535,3 +482,40 @@ _getouter(o, optic::ComposedOptic) = optic.outer(o)
_getouter(o, optic) = o
_setouter(o, optic::ComposedOptic, v) = set(o, optic.outer, v)
_setouter(o, optic, v) = v

################################################################################
##### Lenses
################################################################################
struct PropertyLens{fieldname} end

function (l::PropertyLens{field})(obj) where {field}
getproperty(obj, field)
end

@inline function set(obj, l::PropertyLens{field}, val) where {field}
patch = (;field => val)
setproperties(obj, patch)
end

struct IndexLens{I <: Tuple}
indices::I
end

Base.@propagate_inbounds function (lens::IndexLens)(obj)
getindex(obj, lens.indices...)
end
Base.@propagate_inbounds function set(obj, lens::IndexLens, val)
setindex(obj, val, lens.indices...)
end

struct DynamicIndexLens{F}
f::F
end

Base.@propagate_inbounds function (lens::DynamicIndexLens)(obj)
return obj[lens.f(obj)...]
end

Base.@propagate_inbounds function set(obj, lens::DynamicIndexLens, val)
return setindex(obj, val, lens.f(obj)...)
end
34 changes: 32 additions & 2 deletions src/sugar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export @set, @optic, @reset, @modify
export @set, @optic, @reset, @modify, @getall, @setall
using MacroTools

"""
Expand Down Expand Up @@ -84,7 +84,6 @@ end
This function can be used to create a customized variant of [`@modify`](@ref).
See also [`opticmacro`](@ref), [`setmacro`](@ref).
"""

function modifymacro(optictransform, f, obj_optic)
f = esc(f)
obj, optic = parse_obj_optic(obj_optic)
Expand All @@ -94,6 +93,37 @@ function modifymacro(optictransform, f, obj_optic)
end)
end

"""
@getall f(obj, arg...)
@getall obj isa Number
"""
macro getall(ex)
ex.head == :call || error("@getall must be a function call")
obj = ex.args[2]
var = gensym()
ex.args[2] = var
esc(:(Query($var -> $ex)($obj)))
end

"""
@setall f(obj, arg...) = values
"""
macro setall(ex)
ex.head == :(=) || error("@setall must contain an = assignment")
func = ex.args[1]
vals = ex.args[2]
func.head == :call || error("@setall must contain a function call")
obj = func.args[2]
var = gensym()
func.args[2] = var
esc(:(set($obj, Query($var -> $func), $vals)))
end

dump(:(a = b))

foldtree(op, init, x) = op(init, x)
foldtree(op, init, ex::Expr) =
op(foldl((acc, x) -> foldtree(op, acc, x), ex.args; init=init), ex)
Expand Down
26 changes: 19 additions & 7 deletions test/test_queries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ slowlens = Query(;
optic = (Accessors.@optic _.a) Accessors.Properties()
)

@code_typed lens(obj)
@code_typed slowlens(obj)
lens(obj)
@code_warntype lens(obj)
@code_warntype slowlens(obj)

@code_native lens(obj)
@code_native slowlens(obj)
Expand All @@ -28,26 +29,37 @@ println("get")

missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2)))
@test Query(ismissing)(missings_obj) === (missing, missing, missing)
@btime Query(ismissing)($missings_obj) === (missing, missing, missing)

println("set")
# Need a wrapper so we don't have to pass in the starting iterator
@btime Accessors.set($obj, $lens, $vals)
set(obj, lens, vals)
@btime set($obj, $lens, $vals)
@btime Accessors._set($obj, $lens, ($vals, 1))[1]
# @btime Accessors.set($obj, $slowlens, $vals)
Accessors.set(obj, lens, vals)
@test Accessors.set(obj, lens, vals) ==
Accessors.set(obj, lens, vals) ==
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((a=2.0,), [1]))

@code_warntype Accessors.set(obj, lens, vals)
@code_warntype set(obj, lens, vals)
@code_native set(obj, lens, vals)
@code_native Accessors._set(obj, lens, (vals, 1))[1]

# using Cthulhu
# using ProfileView
# @profview for i in 1:1000000 Accessors.set(obj, lens, vals) end
# @descend Accessors.set(obj, lens, vals)

println("unstable set")
unstable_lens = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x isa NamedTuple)
@btime Accessors.set($obj, $unstable_lens, $vals)
@btime set($obj, $unstable_lens, $vals)
# slow_unstable_lens = Accessors.Query(; select=x -> x isa Number && x > 4, optic=Properties())
# @btime Accessors.set($obj, $slow_unstable_lens, $vals))

# Somehow modify compiles away almost completely
@btime modify(x -> 10x, $obj, $lens)
@test modify(x -> 10x, obj, lens) == (7, (a=170.0, b=2.0f0), ("3", 4, 5.0), ((a=60.0,), [1]))

# Macros
@test (@getall missings_obj isa Number) == (1, 2)
expected = (a=missing, b=5, c=(d=missing, e=(f=missing, g=6)))
@test (@setall missings_obj isa Number = (5, 6)) === expected

0 comments on commit ec26d1c

Please sign in to comment.