Skip to content

Commit a8d02ef

Browse files
committed
add a throttle macro
1 parent 8d948e8 commit a8d02ef

File tree

1 file changed

+55
-7
lines changed

1 file changed

+55
-7
lines changed

src/utils.jl

+55-7
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,9 @@ end
530530
"""
531531
throttle(f, timeout; leading=true, trailing=false)
532532
533-
Return a function that when invoked, will only be triggered at most once
534-
during `timeout` seconds.
533+
Return a function that when called, will only call the given `f` at most
534+
once during `timeout` seconds. Any arguments passed to this new function
535+
are passed to `f`.
535536
536537
Normally, the throttled function will run as much as it can, without ever
537538
going more than once per `wait` duration; but if you'd like to disable the
@@ -540,17 +541,27 @@ the trailing edge, pass `trailing=true`.
540541
541542
# Examples
542543
```jldoctest
543-
julia> a = Flux.throttle(() -> println("Flux"), 2);
544+
julia> noarg = Flux.throttle(() -> println("Flux"), 2);
544545
545-
julia> for i = 1:4 # a called in alternate iterations
546-
a()
546+
julia> for i in 1:4
547+
noarg() # println called in alternate iterations
547548
sleep(1)
548549
end
549550
Flux
550551
Flux
552+
553+
julia> onearg = Flux.throttle(i -> println("step = ", i), 1);
554+
555+
julia> for i in 1:10
556+
onearg(i)
557+
sleep(0.3)
558+
end
559+
step = 1
560+
step = 5
561+
step = 9
551562
```
552563
"""
553-
function throttle(f, timeout; leading=true, trailing=false)
564+
function throttle(f, timeout::Real; leading=true, trailing=false)
554565
cooldown = true
555566
later = nothing
556567
result = nothing
@@ -582,6 +593,44 @@ function throttle(f, timeout; leading=true, trailing=false)
582593
end
583594
end
584595

596+
"""
597+
@throttle timeout expr
598+
599+
Evaluates the given expression at most once every `timeout` seconds.
600+
601+
Internally, it uses [`throttle`](@ref Flux.throttle). But instead of
602+
defining a function outside the loop, it lets you place the code inside
603+
the loop.
604+
605+
# Example
606+
```jldoctest
607+
julia> for i in 1:20
608+
j = 100i
609+
sleep(0.2)
610+
Flux.@throttle 0.9 if iseven(i)
611+
println("i = ", i, ", and j = ", j)
612+
else
613+
println("i = ", i)
614+
end
615+
end
616+
i = 1
617+
i = 6, and j = 600
618+
i = 11
619+
i = 16, and j = 1600
620+
```
621+
"""
622+
macro throttle(timeout::Real, ex)
623+
expr = macroexpand(__module__, ex)
624+
vars = unique(_allsymbols(expr))
625+
@gensym fast slow
626+
Base.eval(__module__, :($fast($(vars...)) = $expr))
627+
Base.eval(__module__, :(const $slow = $throttle($fast, $timeout)))
628+
:($slow($(vars...))) |> esc
629+
end
630+
631+
_allsymbols(s::Symbol) =[s]
632+
_allsymbols(other) = Symbol[]
633+
_allsymbols(ex::Expr) = vcat(_allsymbols.(ex.args)...)
585634

586635
"""
587636
modules(m)
@@ -651,7 +700,6 @@ julia> loss() = rand();
651700
652701
julia> trigger = Flux.patience(() -> loss() < 1, 3);
653702
654-
655703
julia> for i in 1:10
656704
@info "Epoch \$i"
657705
trigger() && break

0 commit comments

Comments
 (0)