Skip to content

Commit

Permalink
working but slow
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Jun 19, 2021
1 parent 8fb1580 commit 370e2fc
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 83 deletions.
130 changes: 64 additions & 66 deletions src/optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ julia> obj = (a=1, b=2); lens=@optic _.a; val = 100;
julia> set(obj, lens, val)
(a = 100, b = 2)
```
See also [`modify`](@ref).
``` See also [`modify`](@ref).
"""
function set end

Expand Down Expand Up @@ -346,15 +345,7 @@ Here `f` has signature `f(::Value, ::State) -> Tuple{NewValue, NewState}`.
"""
function modify_stateful end

@inline function modify_stateful(f, (obj, state), optic::Properties)
let f=f, obj=obj, state=state
modify_stateful_context((obj, state), optic) do _, fn, pr, st
f(getfield(pr, known(fn)), st)
end
end
end

@generated function modify_stateful_context(f, (obj, state1)::T, optic::Properties) where T
@generated function modify_stateful(f::F, (obj, state)::T, optic::Properties) where {T,F}
_modify_stateful_inner(T)
end

Expand All @@ -363,29 +354,29 @@ function _modify_stateful_inner(::Type{<:Tuple{O,S}}) where {O,S}
modifications = []
vals = Expr(:tuple)
fns = fieldnames(O)
local st1 = :state0
local st2 = :state1
for (i, fn) in enumerate(fns)
v = Symbol("val$i")
st1 = Symbol("state$i")
st2 = Symbol("state$(i+1)")
ms = if O <: Tuple
:(($v, $st2) = f(obj, StaticInt{$(QuoteNode(fn))}(), props, $st1))
st = if S <: ContextState
if O <: Tuple
:(ContextState(state.vals, obj, StaticInt{$(QuoteNode(fn))}()))
else
:(ContextState(state.vals, obj, StaticSymbol{$(QuoteNode(fn))}()))
end
else
:(($v, $st2) = f(obj, StaticSymbol{$(QuoteNode(fn))}(), props, $st1))
:state
end
ms = :(($v, state) = f(getfield(props, $(QuoteNode(fn))), $st))
push!(modifications, ms)
push!(vals.args, v)
end
patch = O <: Tuple ? vals : :(NamedTuple{$fns}($vals))
Expr(:block,
:(props = getproperties(obj)),
modifications...,
:(patch = $patch),
:(new_obj = maybesetproperties($st2, obj, patch)),
:(new_state = maybesetstate($st2, obj, patch)),
:(return (setproperties(obj, patch), $st2)),
)
start = :(props = getproperties(obj))
rest = MacroTools.@q begin
patch = $patch
new_obj = maybesetproperties(state, obj, patch)
return (new_obj, state)
end
Expr(:block, start, modifications..., rest)
end

maybesetproperties(state, obj, patch) = setproperties(obj, patch)
Expand Down Expand Up @@ -426,15 +417,10 @@ Query(; select=Any, descend=x -> true, optic=Properties()) = Query(select, desce

OpticStyle(::Type{<:AbstractQuery}) = SetBased()

struct Context{Select,Descend,Optic<:Union{ComposedOptic,Properties}} <: AbstractQuery
select_condition::Select
descent_condition::Descend
optic::Optic
end


struct ContextState{V}
struct ContextState{V,O,FN}
vals::V
obj::O
fn::FN
end
struct GetAllState{V}
vals::V
Expand All @@ -445,57 +431,69 @@ struct SetAllState{C,V,I}
itr::I
end

pop(x) = first(x), Base.tail(x)
push(x, val) = (x..., val)
push(x::GetAllState, val) = GetAllState(push(x.vals, val))
const GetStates = Union{GetAllState,ContextState}

@inline pop(x) = first(x), Base.tail(x)
@inline push(x, val) = (x..., val)
@inline push(x::GetAllState, val) = GetAllState(push(x.vals, val))
@inline push(x::ContextState, val) = ContextState(push(x.vals, val), nothing, nothing)

(q::Query)(obj) = getall(obj, q)

function getall(obj, q)
getall(obj, q) = _getall(obj, q).vals
function _getall(obj, q::Q) where Q<:Query
initial_state = GetAllState(())
_, final_state = modify_stateful((obj, initial_state), q) do o, s
new_state = push(s, outer(q.optic, o, s))
o, new_state
_, final_state = let q=q
modify_stateful((obj, initial_state), q) do o, s
new_state = push(s, outer(q.optic, o, s))
o, new_state
end
end
return final_state.vals
final_state
end

function setall(obj, q, vals)
function setall(obj, q::Q, vals) where Q<:Query
initial_state = SetAllState(Unchanged(), vals, 1)
final_obj, _ = modify_stateful((obj, initial_state), q) do o, s
new_output = outer(q.optic, o, s)
new_state = SetAllState(Changed(), s.vals, s.itr + 1)
new_output, new_state
final_obj, _ = let obj=obj, q=q, initial_state=initial_state
modify_stateful((obj, initial_state), q) do o, s
new_output = outer(q.optic, o, s)
new_state = SetAllState(Changed(), s.vals, s.itr + 1)
new_output, new_state
end
end
return final_obj
end

function context(f, obj, q)
initial_state = GetAllState(())
_, final_state = modify_stateful_context((obj, initial_state), Properties()) do o, fn, pr, s
new_state = push(s, f(o, known(fn)))
o, new_state
function context(f::F, obj, q::Q) where {F,Q<:Query}
initial_state = ContextState((), nothing, nothing)
_, final_state = let f=f
modify_stateful((obj, initial_state), q) do o, s
new_state = push(s, f(s.obj, known(s.fn)))
o, new_state
end
end
return final_state.vals
end

modify(f, obj, q::Query) = setall(obj, q, map(f, getall(obj, q)))

@inline function modify_stateful(f::F, (obj, state), q::Query) where F
modify_stateful((obj, state), inner(q.optic)) do o, s
if q.select_condition(o)
f(o, s)
elseif q.descent_condition(o)
ds = descent_state(s)
o, s = modify_stateful(f::F, (o, ds), q)
o, merge_state(s, ds)
else
o, s
@inline function modify_stateful(f::F, (obj, state), q::Q) where {F,Q<:Query}
let f=f, q=q
modify_stateful((obj, state), inner(q.optic)) do o, s
if (q::Q).select_condition(o)
(f::F)(o, s)
elseif (q::Q).descent_condition(o)
ds = descent_state(s)
o, ns = modify_stateful(f::F, (o, ds), q::Q)
o, merge_state(ds, ns)
else
o, s
end
end
end
end

maybesetproperties(state::GetAllState, obj, patch) = obj
maybesetproperties(state::GetStates, obj, patch) = obj
maybesetproperties(state::SetAllState, obj, patch) =
maybesetproperties(state.change, state, obj, patch)
maybesetproperties(::Changed, state::SetAllState, obj, patch) = setproperties(obj, patch)
Expand All @@ -516,8 +514,8 @@ anychanged(::Changed, ::Changed) = Changed()
inner(optic) = optic
inner(optic::ComposedOptic) = optic.inner

outer(optic, o, state::GetAllState) = o
outer(optic::ComposedOptic, o, state::GetAllState) = optic.outer(o)
outer(optic, o, state::GetStates) = o
outer(optic::ComposedOptic, o, state::GetStates) = optic.outer(o)
outer(optic::ComposedOptic, o, state::SetAllState) = set(o, optic.outer, state.vals[state.itr])
outer(optic, o, state::SetAllState) = state.vals[state.itr]

Expand All @@ -532,7 +530,7 @@ function (l::PropertyLens{field})(obj) where {field}
end

@inline function set(obj, l::PropertyLens{field}, val) where {field}
patch = (;field => val)
patch = (; field => val)
setproperties(obj, patch)
end

Expand Down
36 changes: 19 additions & 17 deletions test/test_queries.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,43 @@
using Accessors, Test, BenchmarkTools, Static
using Accessors: setall, getall, context

obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,), [1,]))
obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,)), [1])
vals = (1.0, 2.0, 3.0, 4.0)

# Fields is the default
q = Query(;
select=x -> x isa NamedTuple,
descend=x -> x isa Tuple,
optic = (Accessors.@optic _.a) Accessors.Properties()
# optic = Accessors.Properties()
)

println("getall")
getall(obj, q)

@code_native getall(obj, q)
@code_warntype getall(obj, q)

@benchmark getall($obj, $q)
@test getall(obj, q) == (17.0, 6.0)

# using ProfileView, Cthulhu
# @descend getall(obj, q)
# f(obj, q) = for i in 1:10000000 getall(obj, q) end
# @profview f(obj, q)

missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2)))
@test getall(missings_obj, Query(ismissing)) === (missing, missing, missing)
@benchmark getall($missings_obj, Query(ismissing))

println("setall")
# Need a wrapper so we don't have to pass in the starting iterator
setall(obj, q, vals)
@benchmark setall($obj, $q, $vals)
# using ProfileView
# @profview for i in 1:1000000 setall(obj, q, vals) end
@code_native setall(obj, q, vals)
@code_warntype setall(obj, q, vals)

# @btime Accessors.set($obj, $slowlens, $vals)
@test setall(obj, q, vals) ==
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=2.0,), [1]))

using Cthulhu
@descend getall(obj, q)
# using ProfileView
# @profview for i in 1:1000000 Accessors.set(obj, lens, vals) end
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=2.0,)), [1])

println("unstable set")
unstable_q = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x isa NamedTuple)
@btime setall($obj, $unstable_q, $vals)
# slow_unstable_lens = Accessors.Query(; select=x -> x isa Number && x > 4, optic=Properties())
Expand All @@ -50,10 +47,15 @@ unstable_q = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x
@btime modify(x -> 10x, $obj, $q)

# Context
obj = (b=2, c=2)
@test context((o, fn) -> fn, obj, q) == (:b, :c)
@test context((o, fn) -> typeof(o), obj, q) == (typeof(obj), typeof(obj))
@btime context((o, fn) -> fn, $obj, $q)
q = Query(;
select=x -> x isa Int,
descend=x -> x isa NamedTuple,
optic = Accessors.Properties()
)
obj2 = (1.0, :a, (b=2, c=2))
@test context((o, fn) -> fn, obj2, q) == (:b, :c)
@test context((o, fn) -> typeof(o), obj2, q) == (typeof(obj2[3]), typeof(obj2[3]))
@btime context((o, fn) -> fn, $obj2, $q)

# Macros
@test (@getall (x for x in missings_obj if x isa Number)) == (1, 2)
Expand Down

0 comments on commit 370e2fc

Please sign in to comment.