Skip to content

Commit

Permalink
Type stability fixes in NestedIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Oct 2, 2019
1 parent 37f9a0c commit 1192fd8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
44 changes: 21 additions & 23 deletions src/Containers/nested_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,49 +24,47 @@ for i1 in iterators[1]()
end
```
"""
struct NestedIterator{T}
struct NestedIterator{T, C}
iterators::T # Tuple of functions
condition::Function
condition::C
end
function nested(iterators...; condition = (args...) -> true)
return NestedIterator(iterators, condition)
end
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state)
function next_iterate(iterators, condition, elems, states, iterator, elem_state)
if elem_state === nothing
return nothing
end
elem, state = elem_state
elems_states = first_iterate(
it, i + 1, (elems..., elem),
Base.tail(iterators), condition, (elems..., elem),
(states..., (iterator, state, elem)))
if elems_states !== nothing
return elems_states
end
return next_iterate(it, i, elems, states, iterator, iterate(iterator, state))
return next_iterate(iterators, condition, elems, states, iterator, iterate(iterator, state))
end
function first_iterate(it::NestedIterator, i, elems, states)
if i > length(it.iterators)
if it.condition(elems...)
return elems, states
else
return nothing
end
end
iterator = it.iterators[i](elems...)
return next_iterate(it, i, elems, states, iterator, iterate(iterator))
end
function tail_iterate(it::NestedIterator, i, elems, states)
if i > length(it.iterators)
function first_iterate(::Tuple{}, condition, elems, states)
if condition(elems...)
return elems, states
else
return nothing
end
next = tail_iterate(it, i + 1, (elems..., states[i][3]), states)
end
function first_iterate(iterators, condition, elems, states)
iterator = iterators[1](elems...)
return next_iterate(iterators, condition, elems, states, iterator, iterate(iterator))
end
tail_iterate(::Tuple{}, condition, elems, states, prev_states) = nothing
function tail_iterate(iterators, condition, elems, states, prev_states)
next = tail_iterate(Base.tail(iterators), condition, (elems..., states[1][3]), Base.tail(states), (prev_states..., states[1]))
if next !== nothing
return next
end
iterator = states[i][1]
next_iterate(it, i, elems, states[1:(i - 1)], iterator, iterate(iterator, states[i][2]))
iterator = states[1][1]
next_iterate(iterators, condition, elems, prev_states, iterator, iterate(iterator, states[1][2]))
end
Base.iterate(it::NestedIterator) = first_iterate(it, 1, tuple(), tuple())
Base.iterate(it::NestedIterator, states) = tail_iterate(it, 1, tuple(), states)
Base.iterate(it::NestedIterator) = first_iterate(it.iterators, it.condition, tuple(), tuple())
Base.iterate(it::NestedIterator, states) = tail_iterate(it.iterators, it.condition, tuple(), states, tuple())
4 changes: 2 additions & 2 deletions test/perf/axis_constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function dense_axis_constraints(n)

@variable(model, x[1:n])
set = MOI.EqualTo(0.0)
@time @constraint(model, con_refs[i = 2:n], x[i] in set)
con_refs = @time @constraint(model, [i = 2:n], x[i] in set)
optimize!(model)
@assert sum_iterate(con_refs) == n - 1
@btime sum_iterate($con_refs)
Expand All @@ -47,7 +47,7 @@ function sparse_axis_constraints(n)

@variable(model, x[1:n])
set = MOI.EqualTo(0.0)
@time @constraint(model, con_refs[i = 1:n; iseven(i)], x[i] in set)
con_refs = @time @constraint(model, [i = 1:n; iseven(i)], x[i] in set)
optimize!(model)
@assert sum_iterate(con_refs) == div(n, 2)
@btime sum_iterate($con_refs)
Expand Down

0 comments on commit 1192fd8

Please sign in to comment.