diff --git a/.gitignore b/.gitignore index b067edd..976666b 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /Manifest.toml +/logs diff --git a/Oolong.toml b/Oolong.toml new file mode 100644 index 0000000..6b577dc --- /dev/null +++ b/Oolong.toml @@ -0,0 +1,13 @@ +banner = true +color = true + +[logging] +log_level = "Debug" +date_format = "yyyy-mm-ddTHH:MM:SS.s" + + [logging.driver_logger] + console_logger.is_expand_stack_trace = true + rotating_logger.path = "./logs" + rotating_logger.file_format = "YYYY-mm-dd.\\l\\o\\g" + + [logging.loki_logger] \ No newline at end of file diff --git a/Project.toml b/Project.toml index 8c8270e..a17f3b2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,19 @@ name = "Oolong" uuid = "c9dcc2fc-6356-41de-aa29-480ea90c21cd" authors = ["Jun Tian and contributors"] -version = "0.1.0" +version = "0.0.1" [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +GarishPrint = "b0ab02a7-8576-43f7-aa76-eaa7c3897c54" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" +LokiLogger = "51d429d1-9683-4c89-86d7-889f440454ef" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" [compat] julia = "1" diff --git a/README.md b/README.md index 07e3de9..e23dbe8 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,28 @@ -# Oolong.jl +
+Oolong.jl logo
+  ____        _                     |  > 是非成败转头空
+ / __ \      | |                    |  > Success or failure,
+| |  | | ___ | | ___  _ __   __ _   |  > right or wrong,
+| |  | |/ _ \| |/ _ \| '_ \ / _` |  |  > all turn out vain.
+| |__| | (_) | | (_) | | | | (_) |  |
+ \____/ \___/|_|\___/|_| |_|\__, |  |  The Immortals by the River 
+                             __/ |  |  -- Yang Shen 
+                            |___/   |  (Translated by Xu Yuanchong) 
+
-*An actor framework for [ReinforcementLearning.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl)* +**Oolong.jl** is a framework for building scalable distributed applications in Julia. -> “是非成败转头空” —— [《临江仙》](https://www.vincentpoon.com/the-immortals-by-the-river-----------------.html) -> [杨慎](https://zh.wikipedia.org/zh-hans/%E6%9D%A8%E6%85%8E) -> -> "Success or failure, right or wrong, all turn out vain." - [*The Immortals by -> the -> River*](https://www.vincentpoon.com/the-immortals-by-the-river-----------------.html), -> [Yang Shen](https://en.wikipedia.org/wiki/Yang_Shen) -> -> (Translated by [Xu Yuanchong](https://en.wikipedia.org/wiki/Xu_Yuanchong)) +## Features -## Roadmap +- Easy to use + Only very minimal APIs are exposed to make this package easy to use (yes, easier than [Distributed.jl](https://docs.julialang.org/en/v1/stdlib/Distributed/)). + +- Non-invasive + Users can easily extend existing packages to apply them in a cluster. -- [x] Figure out a set of simple primitives for running distributed - applications. -- [ ] Apply this package to some typical RL algorithms: - - [x] Parameter server - - [x] Batch serving - - [ ] Add macro to expose a http endpoint - - [ ] A3C - - [ ] D4PG - - [ ] AlphaZero - - [ ] Deep CFR - - [ ] NFSP - - [ ] Evolution algorithms -- [ ] Resource management across nodes -- [ ] State persistence and fault tolerance -- [ ] Configurable logging and dashboard - - [LokiLogger.jl](https://github.com/fredrikekre/LokiLogger.jl) - - [Stipple.jl](https://github.com/GenieFramework/Stipple.jl) +- Fault tolerance + +- Auto scaling ## Get Started @@ -44,86 +36,114 @@ pkg> activate --temp pkg> add https://github.com/JuliaReinforcementLearning/Oolong.jl ``` -`Oolong.jl` adopts the [actor model](https://en.wikipedia.org/wiki/Actor_model) to -parallelize your existing code. One of the core APIs defined in this package is -the `@actor` macro. +See tests for some example usages. (TODO: move typical examples here when APIs are stabled) -```julia -using Oolong +## Examples -A = @actor () -> @info "Hello World" -``` +- Batch evaluation. +- AlphaZero +- Parameter server +- Parameter search -By putting the `@actor` macro before arbitrary callable object, we defined an -**actor**. And we can call it as usual: +Please contact us if you have a concrete scenario but not sure how to use this package! -```julia -A(); -``` +## Deployment -You'll see something like this on your screen: +### Local Machines -``` -Info:[2021-06-30 22:59:51](@/user/#1)Hello World -``` - -Next, let's make sure anonymous functions with positional and keyword arguments -can also work as expected: +### K8S -```julia -A = @actor (msg;suffix="!") -> @info "Hello " * msg * suffix -A("World";suffix="!!!") -# Info:[2021-06-30 23:00:38](@/user/#5)Hello World!!! -``` +## Roadmap -For some functions, we are more interested in the returned value. +1. Stage 1 + 1. Stabilize API + 1. ☑️ `p::PotID = @pot tea [prop=value...]`, define a container over any callable object. + 2. ☑️ `(p::PotID)(args...;kw...)`, which behaves just like `tea(args...;kw...)`, except that it's an async call, at most once delievery, a `Promise` is returned. + 3. ☑️ `msg |> p::PotID` similar to the above one, except that nothing is returned. + 4. ☑️ `(p::PotID).prop`, async call, at most once delievery, return the `prop` of the inner `tea`. + 5. 🧐 `-->`, `<--`, define a streaming pipeline. + 6. 🧐 timed wait on `Promise`. + 2. Features + 1. ☑️ Logging. All messages are sent to primary node by default. + 2. 🧐 RemoteREPL + 3. ☑️ CPU/GPU allocation + 4. 🧐 Auto intall+using dependencies + 5. ☑️ Global configuration + 6. 🧐 Close pot when it is idle for a period + 3. Example usages + 1. 🧐 Parameter search + 2. 🧐 Batch evaluation. + 3. 🧐 AlphaZero + 4. 🧐 Parameter server +2. Stage 2 + 1. Auto1.scaling. Allow workers join/exit? + 1. 🧐 Custom cluster manager + 2. Dashboard + 1. 🧐 [grafana](https://grafana.com/) + 3. Custom Logger + 1. ☑️ [LokiLogger.jl](https://github.com/fredrikekre/LokiLogger.jl) + 2. 🧐 [Stipple.jl](https://github.com/GenieFramework/Stipple.jl) + 4. Tracing + 1. [opentelemetry](https://opentelemetry.io/) +1. Stage 3 + 1. Drop out Distributed.jl? + 1. 🧐 `Future` will transfer the ownership of the underlying data to the caller. Not very efficient when the data is passed back and forth several times in its life circle. + 2. 🧐 differentiate across pots? + 3. 🧐 Python client (transpile, pickle) + 4. 🧐 K8S + 5. 🧐 JuliaHub + 6. 🧐 AWS + 7. 🧐 Azure + +## Design + +### Workflow -```julia -A = @actor msg -> "Hello " * msg -res = A("World") ``` - -Well, different from the general function call, a result similar to `Future` is -returned instead of the real value. We can then fetch the result with the -following syntax: - -```julia -res[] -# "Hello World" + +--------+ + | Flavor | + +--------+ + | + V +-------------+ + +---+---+ | Pot | + | PotID |<===>| | + +---+---+ | PotID | + | | () -> Tea | + | | require | + | +-------------+ + +-------|-------------------------+ + | V boiled somewhere | + | +----+----+ | + | | Channel | | + | +----+----+ | + | | | + | V +-----------+ | + | +--+--+ | PotState | | + | | Tea |<===>| | | + | +--+--+ | Children | | + | | +-----------+ | + | V | + | +----+----+ | + | | Promise | | + | +---------+ | + +---------------------------------+ ``` -To maintain the internal states across different calls, we can also apply `@actor` -to a customized structure: +A `Pot` is mainly a container of an arbitrary object (`tea`) which is instantiated by calling a parameterless function. Whenever a `Pot` receives a `flavor`, the water in the `Pot` is *boiled* first (a `task` is created to process `tea` and `flavor`) if it is cool (the previous `task` was exited by accident or on demand). Some `Pot`s may have a few specific `require`ments (the number of cpu, gpu). If those requirements can not be satisfied, the `Pot` will be pending to wait for new resources. Users can define how `tea` and `flavor` are processed through multiple dispatch on `process(tea, flavor)`. In some `task`s, users may create many other `Pot`s whose references (`PotID`) are stored in `Children`. A `PotID` is simply a path used to locate a `Pot`. -```julia -Base.@kwdef mutable struct Counter - n::Int = 0 -end - -(c::Counter)() = c.n += 1 +### Decisions -A = @actor Counter() +The following design decisions need to be reviewed continuously. -for _ in 1:10 - A() -end +1. Each `Pot` can only be created inside of another `Pot`, which forms a child-parent relation. If no `Pot` is found in the `current_task()`, the parent is bind to `/user` by default. When registering a new `Pot` whose`PotID` is already registerred. The old one will be removed first. This will allow updating `Pot`s dynamically. (Do we really need this feature?) -n = A.n - -n[] -# 10 -``` - -Note that similar to function call, the return of `A.n` is also a `Future` like object. - -### Tips - -- Be careful with `self()` +### FAQ ## Acknowledgement -This package is mainly inspired by the following packages: +This package is mainly inspired by the following projects: -- [Actors.jl](https://github.com/JuliaActors/Actors.jl) +- [Orleans](https://github.com/dotnet/orleans) - [Proto.Actor](https://proto.actor/) - [Ray](https://ray.io/) +- [Actors.jl](https://github.com/JuliaActors/Actors.jl) diff --git a/docs/logo.jl b/docs/logo.jl new file mode 100644 index 0000000..5384962 --- /dev/null +++ b/docs/logo.jl @@ -0,0 +1,42 @@ +using Luxor + +scale = 2 +ratio = (√5 -1)/2 + +w, h = scale * 128 * 2, scale * 128 * 2 + +r1 = w/2 +r2 = r1 * ratio +r3 = r2 * ratio +r4 = r3 * ratio + +c1 = Point(0, 0) +c2 = c1 + Point(r1-r2, r1-r2) * √2/2 +c3 = c1 + Point(r1-r3, r1-r3) * √2/2 +c4 = c1 + Point(r1-r4, r1-r4) * √2/2 + +Drawing(w, h, "logo.svg") +background(1, 1, 1, 0) +Luxor.origin() + +setcolor(1,1,1) +circle(c1, r1, :fill) +setcolor(0.251, 0.388, 0.847) # dark blue +circle(c1, r1-4*scale, :fill) + +setcolor(1,1,1) +circle(c2, r2, :fill) +setcolor(0.796, 0.235, 0.2) # dark red +circle(c2, r2-4*scale, :fill) + +setcolor(1,1,1) +circle(c3, r3, :fill) +setcolor(0.22, 0.596, 0.149) # dark green +circle(c3, r3-4*scale, :fill) + +setcolor(1,1,1) +circle(c4, r4, :fill) +setcolor(0.584, 0.345, 0.698) # dark purple +circle(c4, r4-4*scale, :fill) + +finish() \ No newline at end of file diff --git a/docs/logo.png b/docs/logo.png new file mode 100644 index 0000000..6dd6470 Binary files /dev/null and b/docs/logo.png differ diff --git a/docs/logo.svg b/docs/logo.svg new file mode 100644 index 0000000..1403332 --- /dev/null +++ b/docs/logo.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/src/Oolong.jl b/src/Oolong.jl index 1fb3a9c..346147e 100644 --- a/src/Oolong.jl +++ b/src/Oolong.jl @@ -3,12 +3,10 @@ module Oolong const OL = Oolong export OL -include("core.jl") -include("parameter_server.jl") -include("serve.jl") - -function __init__() - init() -end +include("config.jl") +include("logging.jl") +include("base.jl") +include("core/core.jl") +include("start.jl") end diff --git a/src/base.jl b/src/base.jl new file mode 100644 index 0000000..bf3c201 --- /dev/null +++ b/src/base.jl @@ -0,0 +1,158 @@ +const KEY = :OOLONG + +""" +Similar to `Future`, but we added some customized methods. +""" +struct Promise + f::Future + function Promise(args...) + new(Future(args...)) + end +end + +Base.getindex(p::Promise) = getindex(p.f) +Base.wait(p::Promise) = wait(p.f) + +"Recursively fetch inner value" +function Base.getindex(p::Promise, ::typeof(!)) + x = p.f[] + while x isa Promise || x isa Future + x = x[] + end + x +end + +function Base.getindex(ps::Vector{Promise}) + res = Vector(undef, length(ps)) + @sync for (i, p) in enumerate(ps) + Threads.@spawn begin + res[i] = p[] + end + end + res +end + +struct TimeOutError{T} <: Exception + t::T +end + +Base.showerror(io::IO, err::TimeOutError) = print(io, "failed to complete in $(err.t) seconds") + +""" + p::Promise[t::Number] + +Try to fetch value during a period of `t`. +A [`TimeOutError`](@ref) is thrown if the underlying data is still not ready after `t`. +""" +function Base.getindex(p::Promise, t::Number, pollint=0.1) + res = timedwait(t;pollint=pollint) do + isready(p) + end + if res === :ok + p[] + else + throw(TimeOutError(t)) + end +end + +Base.put!(p::Promise, x) = put!(p.f, x) +Base.isready(p::Promise) = isready(p.f) + +##### + +struct PotNotRegisteredError <: Exception + pid::PotID +end + +Base.showerror(io::IO, err::PotNotRegisteredError) = print(io, "can not find any pot associated with the pid: $(err.pid)") + +##### + +const RESOURCE_REGISTRY = Dict{Symbol, UInt}( + :cpu => () -> Threads.nthreads(), + :gpu => () -> length(CUDA.devices()) +) + +struct ResourceInfo{I<:NamedTuple} + info::I +end + +ResourceInfo() = ResourceInfo(NamedTuple(k=>v() for (k,v) in RESOURCE_REGISTRY)) + +ResourceInfo(;kw...) = ResourceInfo(kw.data) +Base.keys(r::ResourceInfo) = keys(r.info) +Base.getindex(r::ResourceInfo, x) = getindex(r.info, x) +Base.haskey(r::ResourceInfo, x) = haskey(r.info, x) + +function Base.:(<=)(x::ResourceInfo, y::ResourceInfo) + le = true + for k in keys(x) + if haskey(y, k) && x[k] <= y[k] + continue + else + le = false + break + end + end + le +end + +function Base.:(-)(x::ResourceInfo, y::ResourceInfo) + merge(x, (k => x[k]-v for (k,v) in pairs(y))) +end + +struct RequirementNotSatisfiedError <: Exception + required::ResourceInfo + remaining::ResourceInfo +end + +Base.showerror(io::IO, err::RequirementNotSatisfiedError) = print(io, "required: $(err.required), remaining: $(err.remaining)") + +##### + +"System level messages are processed immediately" +abstract type AbstractSysMsg end + +is_prioritized(msg) = false +is_prioritized(msg::AbstractSysMsg) = true + +# !!! force system level messages to be executed immediately +# directly copied from +# https://github.com/JuliaLang/julia/blob/6aaedecc447e3d8226d5027fb13d0c3cbfbfea2a/base/channels.jl#L13-L31 +# with minor modification +function Base.put_buffered(c::Channel, v) + lock(c) + try + while length(c.data) == c.sz_max + Base.check_channel_state(c) + wait(c.cond_put) + end + if is_prioritized(v) + pushfirst!(c.data, v) # !!! force sys msg to be handled immediately + else + push!(c.data, v) # !!! force sys msg to be handled immediately + end + # notify all, since some of the waiters may be on a "fetch" call. + notify(c.cond_take, nothing, true, false) + finally + unlock(c) + end + return v +end + +##### + +"Similar to `RemoteException`, except that we need the `PotID` info." +struct Failure <: Exception + pid::PotID + captured::CapturedException +end + +Failure(captured) = Failure(self(), captured) + +is_prioritized(::Failure) = true + +function Base.showerror(io::IO, f::Failure) + println(io, "In pot $(f.pid) :") + showerror(io, re.captured) +end diff --git a/src/config.jl b/src/config.jl new file mode 100644 index 0000000..63754c3 --- /dev/null +++ b/src/config.jl @@ -0,0 +1,38 @@ +using Configurations +using YAML +using Logging +using Dates + +@option struct ConsoleLoggerConfig + is_expand_stack_trace::Bool = true +end + +@option struct RotatingLoggerConfig + path::String = "logs" + file_format::String = raw"YYYY-mm-dd.\l\o\g" +end + +@option struct DriverLoggerConfig + console_logger::Union{ConsoleLoggerConfig, Nothing}=ConsoleLoggerConfig() + rotating_logger::Union{RotatingLoggerConfig, Nothing}=RotatingLoggerConfig() +end + +@option struct LokiLoggerConfig + url::String = "http://127.0.0.1:3100" +end + +@option struct LoggingConfig + # filter + log_level::String = "Info" + # transformer + date_format::String="yyyy-mm-ddTHH:MM:SS.s" + # sink + driver_logger::Union{DriverLoggerConfig, Nothing} = DriverLoggerConfig() + loki_logger::Union{LokiLoggerConfig, Nothing} = nothing +end + +@option struct Config + banner::Bool = Base.JLOptions().banner != 0 + color::Bool = Base.have_color + logging::LoggingConfig = LoggingConfig() +end \ No newline at end of file diff --git a/src/core.jl b/src/core.jl deleted file mode 100644 index 946e9fd..0000000 --- a/src/core.jl +++ /dev/null @@ -1,507 +0,0 @@ -export @actor - -using Base.Threads -using Distributed -using Dates -using Logging - -const ACTOR_KEY = "OOLONG" - -##### -# System Messages -##### - -abstract type AbstractSysMsg end - -struct SuccessMsg{M} <: AbstractSysMsg - msg::M -end - -struct FailureMsg{R} <: AbstractSysMsg - reason::R -end - -Base.getindex(msg::FailureMsg) = msg.reason - -Base.@kwdef struct StartMsg{F} <: AbstractSysMsg - info::F = nothing -end - -struct StopMsg{R} <: AbstractSysMsg - reason::R -end - -struct RestartMsg <: AbstractSysMsg end -struct PreRestartMsg <: AbstractSysMsg end -struct PostRestartMsg <: AbstractSysMsg end -struct ResumeMsg <: AbstractSysMsg end - -struct StatMsg <: AbstractSysMsg end - -struct FutureWrapper - f::Future - FutureWrapper(args...) = new(Future(args...)) -end - -function Base.getindex(f::FutureWrapper) - res = getindex(f.f) - if res isa SuccessMsg - res.msg - elseif res isa FailureMsg{<:Exception} - throw(res.reason) - else - res - end -end - -Base.put!(f::FutureWrapper, x) = put!(f.f, x) - -##### -# Mailbox -##### - -struct Mailbox - ch::RemoteChannel -end - -const DEFAULT_MAILBOX_SIZE = typemax(Int) - -Mailbox(; size=DEFAULT_MAILBOX_SIZE, pid=myid()) = Mailbox(RemoteChannel(() -> Channel(size), pid)) - -Base.take!(m::Mailbox) = take!(getfield(m, :ch)) - -Base.put!(m::Mailbox, msg) = put!(getfield(m, :ch), msg) - -whereis(m::Mailbox) = getfield(m, :ch).where - -#= -Actor Hierarchy - -NOBODY - └── ROOT - ├── LOGGER - ├── SCHEDULER - | ├── WORKER_1 - | ├── ... - | └── WORKER_N - └── USER - ├── foo - └── bar - └── baz -=# - -struct NoBody end -const NOBODY = NoBody() - -struct RootActor end -const ROOT_ACTOR = RootActor() -const ROOT = Ref{Mailbox}() - -struct SchedulerActor end -const SCHEDULER_ACTOR = SchedulerActor() -const SCHEDULER = Ref{Mailbox}() - -struct SchedulerWorker -end - -struct StagingActor end - -struct UserActor end -const USER_ACTOR = UserActor() -const USER = Ref{Mailbox}() - -struct LoggerActor end -const LOGGER_ACTOR = LoggerActor() -const LOGGER = Ref{Mailbox}() - -##### -# RemoteLogger -##### - -struct RemoteLogger <: AbstractLogger - mailbox - min_level -end - -struct LogMsg - args - kwargs -end - -const DATE_FORMAT = "yyyy-mm-dd HH:MM:SS" - -function Logging.handle_message(logger::RemoteLogger, args...; kwargs...) - kwargs = merge(kwargs.data,( - datetime="$(Dates.format(now(), DATE_FORMAT))", - path=_self().path - )) - logger.mailbox[LogMsg(args, kwargs)] -end - -Logging.shouldlog(::RemoteLogger, args...) = true -Logging.min_enabled_level(L::RemoteLogger) = L.min_level - -##### -# Actor -##### - -Base.@kwdef struct Actor - path::String - thunk::Any - owner::Union{NoBody,Mailbox} - children::Dict{String,Mailbox} - taskref::Ref{Task} - mailbox::Ref{Mailbox} - mailbox_size::Int -end - -Base.nameof(a::Actor) = basename(a.path) - -function Actor( - thunk; - owner=self(), - children=Dict{String,Mailbox}(), - name=string(nameof(thunk)), - path=(isnothing(_self()) ? "/user" : _self().name) * "/" * name, - mailbox=nothing, - mailbox_size=DEFAULT_MAILBOX_SIZE, -) - return Actor( - path, - thunk, - owner, - children, - Ref{Task}(), - isnothing(mailbox) ? Ref{Mailbox}() : Ref{Mailbox}(mailbox), - mailbox_size - ) -end - -function act(A) - logger = isassigned(LOGGER) ? RemoteLogger(LOGGER[], Logging.Debug) : global_logger() - with_logger(logger) do - handler = A.thunk() - while true - try - msg = take!(A.mailbox[]) - handle(handler, msg) - msg isa StopMsg && break - catch exec - @error exec - for (exc, bt) in Base.catch_stack() - showerror(stdout, exc, bt) - println(stdout) - end - action = A.owner(FailureMsg(exec))[] - if action isa ResumeMsg - handle(handler, action) - continue - elseif action isa StopMsg - handle(handler, action) - rethrow() - elseif action isa RestartMsg - handle(handler, PreRestartMsg()) - handler = A.thunk() - handle(handler, PostRestartMsg()) - else - @error "unknown msg received from $(dirname(nameof(A))): $exec" - rethrow() - end - end - end - end -end - -""" -Get the [`Mailbox`](@ref) in the current task. - -!!! note - `self()` in the REPL is bind to `USER`. -""" -function self() - A = _self() - return isnothing(A) ? USER[] : A.mailbox[] -end - -function _self() - try - task_local_storage(ACTOR_KEY) - catch ex - if ex isa KeyError - nothing - else - rethrow() - end - end -end - -function _schedule(A::Actor) - if !isassigned(A.mailbox) - A.mailbox[] = Mailbox(;size=A.mailbox_size) - end - A.taskref[] = Threads.@spawn begin - task_local_storage(ACTOR_KEY, A) - act(A) - end - return A.mailbox[] -end - -struct ScheduleMsg <: AbstractSysMsg - actor::Actor -end - -function Base.schedule(A::Actor) - s = _self() - if isnothing(s) - if A.owner === NOBODY - _schedule(A) - else - # the actor is submitted from REPL - # we schedule the actor through USER so that it will be bind to USER - USER[](ScheduleMsg(A))[] - end - else - if A.owner === ROOT[] - mailbox = _schedule(A) - else - mailbox = SCHEDULER[](ScheduleMsg(A))[] - end - s.children[nameof(A)] = mailbox - mailbox - end -end - - -macro actor(exs...) - a = exs[1] - name = if a isa Symbol - string(a) - elseif a isa Expr && a.head == :call - string(a.args[1]) - else - nothing - end - - default_kw = isnothing(name) ? (;) : (;name=name) - thunk = esc(:(() -> ($(a)))) - kwargs = [esc(x) for x in exs[2:end]] - kw = :(merge($default_kw, (;$(kwargs...)))) - - quote - schedule(Actor($thunk; $kw...)) - end -end - -##### -# System Behaviors -##### -function handle(x, args...;kwargs...) - x(args...;kwargs...) -end - -function handle(x, ::FailureMsg) - RestartMsg() -end - -function handle(x, ::PreRestartMsg) - @debug "stopping children before restart" - handle(x, StopMsg("stop before restart")) -end - -function handle(x, ::PostRestartMsg) - @debug "starting after restart signal" - handle(x, StartMsg(:restart)) -end - -function handle(x, msg::StartMsg) - @debug "start msg received" -end - -function handle(x, msg::StopMsg) - for c in values(_self().children) - c(msg)[] # ??? blocking - end -end - -struct ActorStat - path::String -end - -function handle(x, ::StatMsg) - s = _self() - ActorStat( - s.path - ) -end - -Base.stat(m::Mailbox) = m(StatMsg())[] -Base.pathof(m::Mailbox) = stat(m).path -Base.nameof(m::Mailbox) = basename(pathof(m)) - -function handle(::RootActor, s::StartMsg) - @info "$(@__MODULE__) starting..." - LOGGER[] = @actor LOGGER_ACTOR path="/logger" - LOGGER[](s)[] # blocking to ensure LOGGER has started - SCHEDULER[] = @actor SCHEDULER_ACTOR path="/scheduler" - SCHEDULER[](s)[] # blocking to ensure SCHEDULER has started - USER[] = @actor USER_ACTOR path = "/user" - USER[](s)[] # blocking to ensure USER has started -end - -function handle(::LoggerActor, ::StartMsg) - @info "LOGGER started" -end - -function handle(L::LoggerActor, msg::LogMsg) - buf = IOBuffer() - iob = IOContext(buf, stderr) - - level, message, _module, group, id, file, line = msg.args - - color, prefix, suffix = Logging.default_metafmt( - level, _module, group, id, file, line - ) - printstyled(iob, prefix; bold=true, color=color) - printstyled(iob, "[$(msg.kwargs.datetime)]"; color=:light_black) - printstyled(iob, "(@$(msg.kwargs.path))"; color=:green) - print(iob, message) - for (k,v) in pairs(msg.kwargs) - if k ∉ (:datetime, :path) - print(iob, " ") - printstyled(iob, k; color=:yellow) - print(iob, "=") - print(iob, v) - end - end - !isempty(suffix) && printstyled(iob, "($suffix)"; color=:light_black) - println(iob) - write(stderr, take!(buf)) -end - -function handle(::SchedulerActor, ::StartMsg) - @info "SCHEDULER started" -end - -function handle(::SchedulerActor, msg::ScheduleMsg) - # TODO: schedule it smartly based on workers' status - @debug "scheduling $(nameof(msg.actor))" - _schedule(msg.actor) -end - -function handle(::UserActor, ::StartMsg) - @info "USER started" -end - -function handle(::UserActor, s::ScheduleMsg) - mailbox = SCHEDULER[](s)[] - _self().children[nameof(s.actor)] = mailbox -end - -##### -# Syntax Sugar -##### - -##### - -struct CallMsg{T} - args::T - kwargs - value_box -end - -function handle(x, c::CallMsg) - try - put!(c.value_box, SuccessMsg(handle(x, c.args...; c.kwargs...))) - catch exec - put!(c.value_box, FailureMsg(exec)) - rethrow() - end - nothing -end - -function (m::Mailbox)(args...;kwargs...) - value_box = FutureWrapper(whereis(m)) - msg = CallMsg(args, kwargs, value_box) - put!(m, msg) - value_box -end - -##### - -struct CastMsg{T} - args::T -end - -handle(x, c::CastMsg) = handle(x, c.args...) - -function Base.getindex(m::Mailbox, args...) - put!(m, CastMsg(args)) - nothing -end - -##### - -struct GetPropMsg - name::Symbol - value_box::FutureWrapper -end - -handle(x, p::GetPropMsg) = put!(p.value_box, getproperty(x, p.name)) - -function Base.getproperty(m::Mailbox, name::Symbol) - res = FutureWrapper(whereis(m)) - put!(m, GetPropMsg(name, res)) - res -end - -##### - -Base.@kwdef struct RequestMsg{M} - msg::M - from::Mailbox = self() -end - -Base.@kwdef struct ReplyMsg{M} - msg::M - from::Mailbox = self() -end - -req(x::Mailbox, msg) = put!(x, RequestMsg(msg=msg)) -rep(x::Mailbox, msg) = put!(x, ReplyMsg(msg=msg)) -async_req(x::Mailbox, msg) = Threads.@spawn put!(x, RequestMsg(msg=msg)) -async_rep(x::Mailbox, msg) = Threads.@spawn put!(x, ReplyMsg(msg=msg)) - -handle(x, req::RequestMsg) = rep(req.from, handle(x, req.msg)) - -# !!! force system level messages to be executed immediately -# directly copied from -# https://github.com/JuliaLang/julia/blob/6aaedecc447e3d8226d5027fb13d0c3cbfbfea2a/base/channels.jl#L13-L31 -# with minor modification -function Base.put_buffered( - c::Channel, - v::Union{ - AbstractSysMsg, - CallMsg{<:Tuple{<:AbstractSysMsg}}, - CastMsg{<:Tuple{<:AbstractSysMsg}} - } -) - lock(c) - try - while length(c.data) == c.sz_max - Base.check_channel_state(c) - wait(c.cond_put) - end - pushfirst!(c.data, v) # !!! force sys msg to be handled immediately - # notify all, since some of the waiters may be on a "fetch" call. - notify(c.cond_take, nothing, true, false) - finally - unlock(c) - end - return v -end - -# !!! This should be called ONLY once -function init() - ROOT[] = @actor ROOT_ACTOR owner=NOBODY path="/" - ROOT[](StartMsg(nothing))[] # blocking is required -end diff --git a/src/core/core.jl b/src/core/core.jl new file mode 100644 index 0000000..20dc60b --- /dev/null +++ b/src/core/core.jl @@ -0,0 +1,4 @@ +include("pot_id.jl") +include("pot.jl") +include("message_handling.jl") +include("scheduling.jl") \ No newline at end of file diff --git a/src/core/message_handling.jl b/src/core/message_handling.jl new file mode 100644 index 0000000..51ce506 --- /dev/null +++ b/src/core/message_handling.jl @@ -0,0 +1,136 @@ +process(tea, args...;kw...) = tea(args...;kw...) + +function Base.put!(p::PotID, flavor) + try + put!(p[], flavor) + catch e + # TODO add test + if e isa PotNotRegisteredError + rethrow(e) + else + @error e + boil(p) + put!(p, flavor) + end + end +end + +##### + +struct CallMsg{A} + args::A + kw + promise +end + +is_prioritized(::CallMsg{<:Tuple{<:AbstractSysMsg}}) = true + +function (p::PotID)(args...;kw...) + promise = Promise(whereis(p)) # !!! the result should reside in the same place + put!(p, CallMsg(args, kw.data, promise)) + promise +end + +# ??? non specialized tea? +function process(tea, msg::CallMsg) + try + res = process(tea, msg.args...;msg.kw...) + put!(msg.promise, res) + catch err + ce = CapturedException(err, catch_backtrace()) + put!(msg.promise, Failure(ce)) + rethrow(err) + end +end + +##### + +function Base.:(|>)(x, p::PotID) + put!(p, x) + nothing +end + +##### + +struct GetPropMsg + prop::Symbol +end + +Base.getproperty(p::PotID, prop::Symbol) = p(GetPropMsg(prop)) + +process(tea, msg::GetPropMsg) = getproperty(tea, msg.prop) + +##### SysMsg + +struct Exit +end + +const EXIT = Exit() + +""" + CloseWhenIdleMsg(t::Int) + +Signal a Pot and its children to close the channel and release claimed resources if the Pot has been idle for `t` seconds. +""" +struct CloseWhenIdleMsg <: AbstractSysMsg + t::Int +end + +function process(tea, msg::CloseWhenIdleMsg) + t_idle = (now() - _self().last_update) / Millisecond(1_000) + if t_idle >= msg.t && isempty(_self().ch) + for c in children() + msg |> c + end + EXIT + end +end + +##### + +""" +Close the active channel and remove the registered `Pot`. +""" +struct RemoveMsg <: AbstractSysMsg +end + +Base.rm(p::PotID) = p(RemoveMsg()) + +function process(tea, msg::RemoveMsg) + # !!! note the order + for c in children() + c(msg)[] + end + unregister(self()) + close(_self().ch) + EXIT +end + +##### + +struct ResumeMsg <: AbstractSysMsg +end + +process(tea, ::ResumeMsg) = nothing + +##### + +struct RestartMsg <: AbstractSysMsg +end + +const RESTART = RestartMsg() + +process(tea, ::RestartMsg) = RESTART + +struct PreRestartMsg <: AbstractSysMsg +end + +process(tea, ::PreRestartMsg) = nothing + +struct PostRestartMsg <: AbstractSysMsg +end + +process(tea, ::PostRestartMsg) = nothing + +process(tea, ::Failure) = RESTART + diff --git a/src/core/pot.jl b/src/core/pot.jl new file mode 100644 index 0000000..bacab5a --- /dev/null +++ b/src/core/pot.jl @@ -0,0 +1,47 @@ +struct Pot + tea_bag::Any + pid::PotID + require::ResourceInfo + logger::Any +end + +function Pot( + tea_bag; + name=string(uuid4()), + cpu=eps(), + gpu=0, + logger=DEFAULT_LOGGER +) + pid = name isa PotID ? name : PotID(name) + require = ResourceInfo(cpu, gpu) + Pot(tea_bag, pid, require, logger) +end + +macro pot(tea, kw...) + tea_bag = esc(:(() -> ($(tea)))) + xs = [esc(x) for x in kw] + quote + p = Pot($tea_bag; $(xs...)) + register(p) + p.pid + end +end + +mutable struct PotState + pid::PotID + ch::Channel + create_time::DateTime + last_update::DateTime + n_processed::UInt +end + +_self() = get!(task_local_storage(), KEY, PotState(USER, current_task())) +self() = _self().pid + +local_scheduler() = SCHEDULER/"local_scheduler_$(myid())" + +Base.parent() = parent(self()) +Base.parent(p::PotID) = PotID(getfield(p, :path[1:end-1])) + +children() = children(self()) + diff --git a/src/core/pot_id.jl b/src/core/pot_id.jl new file mode 100644 index 0000000..e28b6bb --- /dev/null +++ b/src/core/pot_id.jl @@ -0,0 +1,49 @@ +export @P_str + +struct PotID + path::Tuple{Vararg{Symbol}} +end + +""" + P"[/]your/pot/path" + +The path can be either relative or absolute path. If a relative path is provided, it will be resolved to an absolute path based on the current context. + +!!! note + We don't validate the path for you during construction. A [`PotNotRegisteredError`](@ref) will be thrown when you try to send messages to an unregistered path. +""" +macro P_str(s) + PotID(s) +end + +function Base.show(io::IO, p::PotID) + if isempty(getfield(p, :path)) + print(io, "/") + else + for x in getfield(p, :path) + print(io, '/') + print(io, x) + end + end +end + +function PotID(s::String) + if length(s) > 0 + if s[1] == '/' + PotID(Tuple(Symbol(x) for x in split(s, '/';keepempty=false))) + else + self() / s + end + else + PotID(()) + end +end + +function Base.:(/)(p::PotID, s::String) + PotID((getfield(p, :path)..., Symbol(s))) +end + +const ROOT = P"/" +const LOGGER = P"/log" +const SCHEDULER = P"/scheduler" +const USER = P"/user" diff --git a/src/core/scheduling.jl b/src/core/scheduling.jl new file mode 100644 index 0000000..e62daa5 --- /dev/null +++ b/src/core/scheduling.jl @@ -0,0 +1,272 @@ +# TODO: set ttl? + +""" +Local cache on each worker to reduce remote call. +The links may be staled. +""" +const POT_LINK_CACHE = Dict{PotID, RemoteChannel{Channel{Any}}}() + +""" +Only valid on the driver to keep track of all registered pots. +TODO: use a kv db +""" +const POT_REGISTRY = Dict{PotID, Pot}() +const POT_CHILDREN = Dict{PotID, Set{PotID}}() + +function is_registered(p::Pot) + is_exist = remotecall_wait(1) do + haskey(Oolong.POT_REGISTRY, p.pid) + end + is_exist[] +end + +function register(p::Pot) + remotecall_wait(1) do + Oolong.POT_REGISTRY[p.pid] = p + children = get!(Oolong.POT_CHILDREN, parent(p.pid), Set{PotID}()) + push!(children, p.pid) + end +end + +function unregister(p::PotID) + remotecall_wait(1) do + delete!(Oolong.POT_REGISTRY, p) + delete!(Oolong.POT_CHILDREN, p) + end +end + +function children(p::PotID) + remotecall_wait(1) do + # ??? data race + get!(Oolong.POT_CHILDREN, p, Set{PotID}()) + end +end + +function link(p::PotID, ch::RemoteChannel) + POT_LINK_CACHE[p] = ch + if myid() != 1 + remotecall_wait(1) do + Oolong.POT_LINK_CACHE[p] = ch + end + end +end + +function Base.getindex(p::PotID) + get!(POT_LINK_CACHE, p) do + ch = remotecall_wait(1) do + get(Oolong.POT_LINK_CACHE, p, nothing) + end + if isnothing(ch[]) + boil(p) + else + ch[] + end + end +end + +whereis(p::PotID) = p[].where + +function Base.getindex(p::PotID, ::typeof(!)) + pot = remotecall_wait(1) do + get(Oolong.POT_REGISTRY, p, nothing) + end + if isnothing(pot[]) + throw(PotNotRegisteredError(p)) + else + pot[] + end +end + +""" +For debug only. Only a snapshot is returned. +!!! DO NOT MODIFY THE RESULT DIRECTLY +""" +Base.getindex(p::PotID, ::typeof(*)) = p(_self())[] + +local_boil(p::PotID) = local_boil(p[!]) + +function local_boil(p::Pot) + pid, tea_bag, logger = p.pid, p.tea_bag, p.logger + ch = RemoteChannel() do + Channel(typemax(Int),spawn=true) do ch + task_local_storage(KEY, PotState(pid, current_task())) + with_logger(logger) do + tea = tea_bag() + while true + try + flavor = take!(ch) + process(tea, flavor) + if flavor isa CloseMsg || flavor isa RemoveMsg + break + end + catch err + @debug err + flavor = parent()(err)[] + if msg isa ResumeMsg + process(tea, flavor) + elseif msg isa CloseMsg + process(tea, flavor) + break + elseif msg isa RestartMsg + process(tea, PreRestartMsg()) + tea = tea_bag() + process(tea, PostRestartMsg()) + else + @error "unknown msg received from parent: $exec" + rethrow() + end + finally + end + end + end + end + end + link(pid, ch) + ch +end + +"blocking until a valid channel is established" +boil(p::PotID) = local_scheduler()(p)[!] + +struct CPUInfo + total_threads::Int + allocated_threads::Int + total_memory::Int + free_memory::Int + function CPUInfo() + new( + Sys.CPU_THREADS, + Threads.nthreads(), + convert(Int, Sys.total_memory()), + convert(Int, Sys.free_memory()), + ) + end +end + +struct GPUInfo + name::String + total_memory::Int + free_memory::Int + function GPUInfo() + new( + name(device()), + CUDA.total_memory(), + CUDA.available_memory() + ) + end +end + +struct ResourceInfo + cpu::CPUInfo + gpu::Vector{GPUInfo} +end + +function ResourceInfo() + cpu = CPUInfo() + gpu = [] + if CUDA.functional() + for d in devices() + device!(d) do + push!(gpu, GPUInfo()) + end + end + end + ResourceInfo(cpu, gpu) +end + +Base.convert(::Type{ResourceInfo}, r::ResourceInfo) = ResourceInfo(r.cpu.allocated_threads, length(r.gpu)) + +struct HeartBeat + resource::ResourceInfo + available::ResourceInfo + from::PotID +end + +struct LocalScheduler + pending::Dict{PotID, Future} + peers::Ref{Dict{PotID, ResourceInfo}} + available::Ref{ResourceInfo} + timer::Timer +end + +# TODO: watch exit info + +function LocalScheduler() + pid = self() + req = convert(ResourceInfo, ResourceInfo()) + available = Ref(req) + timer = Timer(1;interval=1) do t + HeartBeat(ResourceInfo(), available[], pid) |> SCHEDULER # !!! non blocking + end + + pending = Dict{PotID, Future}() + peers = Ref(Dict{PotID, ResourceInfo}(pid => req)) + + LocalScheduler(pending, peers, available, timer) +end + +function (s::LocalScheduler)(p::PotID) + pot = p[!] + if pot.require <= s.available[] + res = local_boil(p) + s.available[] -= pot.require + res + else + res = Future() + s.pending[p] = res + res + end +end + +function (s::LocalScheduler)(peers::Dict{PotID, ResourceInfo}) + s.peers[] = peers + for (p, f) in s.pending + pot = p[!] + for (w, r) in peers + if pot.require <= r + # transfer to w + put!(f, w(p)) + delete!(s.pending, p) + break + end + end + end +end + +Base.@kwdef struct Scheduler + workers::Dict{PotID, HeartBeat} = Dict() + pending::Dict{PotID, Future} = Dict() +end + +# ??? throttle +function (s::Scheduler)(h::HeartBeat) + # ??? TTL + s.workers[h.from] = h + + for (p, f) in s.pending + pot = p[!] + if pot.require <= h.available + put!(f, h.from(p)) + end + end + + Dict( + p => h.available + for (p, h) in s.workers + ) |> h.from # !!! non blocking +end + +# pots are all scheduled on workers only +function (s::Scheduler)(p::PotID) + pot = p[!] + for (w, h) in s.workers + if pot.require <= h.available + return w(p) + end + end + res = Future() + s.pending[p] = res + res +end + + diff --git a/src/logging.jl b/src/logging.jl new file mode 100644 index 0000000..259b07c --- /dev/null +++ b/src/logging.jl @@ -0,0 +1,148 @@ +using LoggingExtras +using LokiLogger +using Dates + +# https://github.com/JuliaLang/julia/blob/1b93d53fc4bb59350ada898038ed4de2994cce33/base/logging.jl#L142-L151 +function Base.parse(::Type{LogLevel}, s::String) + if s == string(Logging.BelowMinLevel) Logging.BelowMinLevel + elseif s == string(Logging.Debug) Logging.Debug + elseif s == string(Logging.Info) Logging.Info + elseif s == string(Logging.Warn) Logging.Warn + elseif s == string(Logging.Error) Logging.Error + elseif s == string(Logging.AboveMaxLevel) Logging.AboveMaxLevel + else + m = match(r"LogLevel\((?-?[1-9]\d*)\)", s) + if isnothing(m) + throw(ArgumentError("unknown log level")) + else + Logging.LogLevel(parse(Int, m[:level])) + end + end + +end + +function create_log_transformer(date_format) + function transformer(log) + merge( + log, + ( + datetime = Dates.format(now(), date_format), + from=self(), + myid=myid(), + ) + ) + end +end + +function create_default_fmt(with_color=false, is_expand_stack_trace=false) + function default_fmt(iob, args) + level, message, _module, group, id, file, line, kw = args + color, prefix, suffix = Logging.default_metafmt( + level, _module, group, id, file, line + ) + ignore_fields = (:datetime, :path, :myid, :from) + if with_color + printstyled(iob, "$(kw.datetime) "; color=:light_black) + printstyled(iob, prefix; bold=true, color=color) + printstyled(iob, "[$(kw.from)@$(kw.myid)]"; color=:green) + print(iob, message) + for (k,v) in pairs(kw) + if k ∉ ignore_fields + print(iob, " ") + printstyled(iob, k; color=:yellow) + printstyled(iob, "="; color=:light_black) + print(iob, v) + end + end + !isempty(suffix) && printstyled(iob, " ($suffix)"; color=:light_black) + println(iob) + else + print(iob, "$(kw.datetime) $prefix[$(kw.from)@$(kw.myid)]$message") + for (k,v) in pairs(kw) + if k ∉ ignore_fields + print(iob, " $k=$v") + end + end + !isempty(suffix) && print(iob, " ($suffix)") + println(iob) + end + end +end + +function create_logger(config::Config) + sinks = [] + + if !isnothing(config.logging.loki_logger) + push!(sinks, LokiLogger.Logger(config.logging.loki_logger.url)) + end + + if !isnothing(config.logging.driver_logger) + driver_sinks = [] + console_logger_config = config.logging.driver_logger.console_logger + if !isnothing(console_logger_config) + push!( + driver_sinks, + FormatLogger( + create_default_fmt( + config.color, + console_logger_config.is_expand_stack_trace + ) + ) + ) + end + rotating_logger_config = config.logging.driver_logger.rotating_logger + if !isnothing(rotating_logger_config) + mkpath(rotating_logger_config.path) + push!( + driver_sinks, + DatetimeRotatingFileLogger( + create_default_fmt(), + rotating_logger_config.path, + rotating_logger_config.file_format, + ) + ) + end + if isempty(driver_sinks) + push!(driver_sinks, current_logger()) + end + push!(sinks, DriverLogger(TeeLogger(driver_sinks...))) + end + + if isempty(sinks) + push!(sinks, current_logger()) + end + + TeeLogger( + ( + MinLevelLogger( + TransformerLogger( + create_log_transformer(config.logging.date_format), + s + ), + parse(Logging.LogLevel, config.logging.log_level) + ) + for s in sinks + )... + ) +end + +##### + +Base.@kwdef struct DriverLogger <: AbstractLogger + logger::TeeLogger +end + +Logging.shouldlog(::DriverLogger, args...) = true +Logging.min_enabled_level(::DriverLogger) = Logging.BelowMinLevel +Logging.catch_exceptions(::DriverLogger) = true + +struct LogMsg + args + kw +end + +Logging.handle_message(logger::DriverLogger, args...; kw...) = LogMsg(args, kw) |> LOGGER + +function (L::DriverLogger)(msg::LogMsg) + handle_message(L.logger, msg.args...;msg.kw...) +end diff --git a/src/parameter_server.jl b/src/parameter_server.jl deleted file mode 100644 index aff7237..0000000 --- a/src/parameter_server.jl +++ /dev/null @@ -1,21 +0,0 @@ -struct ParameterServer - params -end - -function (ps::ParameterServer)(gs) - for (p, g) in zip(ps.params, gs) - p .-= g - end -end - -(ps::ParameterServer)() = deepcopy(p.params) - -# Example usage -# ```julia -# ps = @actor ParameterServer([zeros(Float32, 3, 4), zeros(Float32, 3)]) -# for c in clients -# params = ps()[] -# gs = calc_gradients(params) -# ps[gs] -# end -# ``` \ No newline at end of file diff --git a/src/scheduler.jl b/src/scheduler.jl deleted file mode 100644 index 72e288b..0000000 --- a/src/scheduler.jl +++ /dev/null @@ -1,15 +0,0 @@ -#= - -# Design Doc - -- Each node *usually* creates ONE processor. -- Each processor has ONE `LocalScheduler` -- Each `LocalScheduler` sends its available resources to the global `Scheduler` on the driver processor periodically. -- `LocalScheduler` tries to initialize actors on the local processor when available resource meets. Otherwise, send the actor initializer to the global `Scheduler`. -- If there's no resource matches, the actor is put into the `StagingArea` and will be reschedulered periodically. - -Auto scaling: - -- When the whole cluster is under presure, the global `Scheduler` may send request to claim for more resources. -- When most nodes are idle, the global `Scheduler` may ask some nodes to exit and reschedule the actors on it. -=# \ No newline at end of file diff --git a/src/serve.jl b/src/serve.jl deleted file mode 100644 index 92b8316..0000000 --- a/src/serve.jl +++ /dev/null @@ -1,114 +0,0 @@ - -# directly taken from https://github.com/FluxML/Flux.jl/blob/27c4c77dc5abd8e791f4ca4e68a65fc7a91ebcfd/src/utils.jl#L544-L566 -batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i) - -function batch(xs) - data = first(xs) isa AbstractArray ? - similar(first(xs), size(first(xs))..., length(xs)) : - Vector{eltype(xs)}(undef, length(xs)) - for (i, x) in enumerate(xs) - data[batchindex(data, i)...] = x - end - return data -end - -struct ProcessBatchMsg end - -mutable struct BatchStrategy - buffer::Vector{Any} - reqs::Vector{Mailbox} - model::Mailbox - batch_wait_timeout_s::Float64 - max_batch_size::Int - timer::Union{Nothing, Timer} - n_ongoing_batches::Int -end - -""" - BatchStrategy(model;kwargs...) - -# Keyword Arguments - -- `model::Mailbox`, an actor which wraps the model. This actor must accepts - [`RequestMsg`](@ref) as input and reply with a [`ReplyMsg`](@ref) - correspondingly. -- `batch_wait_timeout_s=0.0`, time to wait before handling the next batch. -- `max_batch_size=8, the maximum batch size to handle each time. - -Everytime we processed a batch, we create a timer and wait for at most -`batch_wait_timeout_s` to handle the next batch. If we get `max_batch_size` -requests before reaching `batch_wait_timeout_s`, the timer is reset. If -`batch_wait_timeout_s==0`, we process the available requests immediately. - -!!! warning - The `model` must reply in a non-blocking way (by using [`async_rep`](@ref) or). - Otherwise, there may be deadlock (see test cases if you are interested). -""" -function BatchStrategy( - model; - batch_wait_timeout_s=0.0, - max_batch_size=8, -) - mb = self() - if batch_wait_timeout_s == 0. - timer = nothing - else - timer = Timer(batch_wait_timeout_s) do timer - mb[ProcessBatchMsg()] - end - end - BatchStrategy( - Vector(), - Vector{Mailbox}(), - model, - batch_wait_timeout_s, - max_batch_size, - nothing, - 0 - ) -end - -function reset_timer!(s::BatchStrategy) - isnothing(s.timer) || close(s.timer) - mb = self() - s.timer = Timer(s.batch_wait_timeout_s) do t - mb[ProcessBatchMsg()] - end -end - -function handle(s::BatchStrategy, req::RequestMsg) - push!(s.buffer, req.msg) - push!(s.reqs, req.from) - - if length(s.buffer) == 1 - if s.batch_wait_timeout_s == 0 - if s.n_ongoing_batches == 0 - s(ProcessBatchMsg()) - end - else - reset_timer!(s) # set a timer to insert a ProcessBatchMsg to self() - end - elseif length(s.buffer) == s.max_batch_size - s(ProcessBatchMsg()) - end -end - -function (s::BatchStrategy)(::ProcessBatchMsg) - if !isempty(s.buffer) - @info "???" s.buffer - data = length(s.buffer) == 1 ? reshape(s.buffer[1], size(s.buffer[1])..., 1) : batch(s.buffer) - empty!(s.buffer) - s.n_ongoing_batches += 1 - req(s.model, data) - end -end - -function (s::BatchStrategy)(msg::ReplyMsg) - for res in msg.msg - rep(popfirst!(s.reqs), res) - end - s.n_ongoing_batches -= 1 - s(ProcessBatchMsg()) -end - -# X** \ No newline at end of file diff --git a/src/start.jl b/src/start.jl new file mode 100644 index 0000000..a7ff638 --- /dev/null +++ b/src/start.jl @@ -0,0 +1,70 @@ +struct Root + function Root() + local_boil(@pot DefaultLogger() name=LOGGER logger=current_logger()) + local_boil(@pot Scheduler() name=SCHEDULER) + new() + end +end + +function banner(io::IO=stdout;color=true) + c = Base.text_colors + tx = c[:normal] # text + d1 = c[:bold] * c[:blue] # first dot + d2 = c[:bold] * c[:red] # second dot + d3 = c[:bold] * c[:green] # third dot + d4 = c[:bold] * c[:magenta] # fourth dot + + if color + print(io, + """ + ____ _ | > 是非成败转头空 + / $(d1)__$(tx) \\ | | | > Success or failure, + | $(d1)| |$(tx) | ___ | | ___ _ __ __ _ | > right or wrong, + | $(d1)| |$(tx) |/ $(d2)_$(tx) \\| |/ $(d3)_$(tx) \\| '_ \\ / $(d4)_$(tx)` | | > all turn out vain. + | $(d1)|__|$(tx) | $(d2)(_)$(tx) | | $(d3)(_)$(tx) | | | | $(d4)(_)$(tx) | | + \\____/ \\___/|_|\\___/|_| |_|\\__, | | The Immortals by the River + __/ | | -- Yang Shen + |___/ | (Translated by Xu Yuanchong) + """) + else + print(io, + """ + ____ _ | > 是非成败转头空 + / __ \\ | | | > Success or failure, + | | | | ___ | | ___ _ __ __ _ | > right or wrong, + | | | |/ _ \\| |/ _ \\| '_ \\ / _` | | > all turn out vain. + | |__| | (_) | | (_) | | | | (_) | | + \\____/ \\___/|_|\\___/|_| |_|\\__, | | The Immortals by the River + __/ | | -- Yang Shen + |___/ | (Translated by Xu Yuanchong) + """) + end +end + +function start(config_file::String="Oolong.yaml";kw...) + config = nothing + if isfile(config_file) + @info "Found $config_file. Loading configs..." + config = Configurations.from_dict(Config, YAML.load_file(config_file; dicttype=Dict{String, Any});kw...) + else + @info "$config_file not found. Using default configs." + config = Config(;kw...) + end + start(config) +end + +function start(config::Config) + config.banner && banner(color=config.color) + + @info "$(@__MODULE__) starting..." + if myid() == 1 + local_boil(@pot Root() name=ROOT logger=current_logger()) + end + + if myid() in workers() + local_boil(@pot LocalScheduler() name=local_scheduler()) + end +end + +function stop() +end diff --git a/test/runtests.jl b/test/runtests.jl index 961380e..2959e90 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,5 +4,4 @@ using Base.Threads @testset "Oolong.jl" begin include("core.jl") - include("serve.jl") end diff --git a/test/serve.jl b/test/serve.jl deleted file mode 100644 index 60ee35c..0000000 --- a/test/serve.jl +++ /dev/null @@ -1,52 +0,0 @@ -@testset "serve" begin - # q: request - # p: reply - # |: batch_wait_timeout_s reached - - @testset "batch_wait_timeout_s = 0" begin - # case 1: qpqpqp - struct DummyModel end - - function OL.handle(m::DummyModel, msg::OL.RequestMsg) - xs = msg.msg - @debug "value received in model" xs typeof(xs) - sleep(0.1) - res = vec(sum(xs; dims=1)) .+ 1 - OL.async_rep(msg.from, res) - end - model = @actor DummyModel() name="model" - - server = @actor OL.BatchStrategy(model;max_batch_size=2) name="server" - - worker = OL.Mailbox() - t = @elapsed for i in 1:10 - put!(server,OL.RequestMsg([i], worker)) - sleep(0.11) - end - - for i in 1:10 - msg = take!(worker) - @test msg.msg == i+1 - end - - # case 2: qqqqqpqqqqp - t = @elapsed begin - for i in 1:5 - put!(server,OL.RequestMsg([i], worker)) - end - for i in 1:5 - @test take!(worker).msg == i+1 - end - end - - t = @elapsed begin - for i in 1:64 - put!(server,OL.RequestMsg([i], worker)) - end - for i in 1:64 - @test take!(worker).msg == i+1 - end - end - end - -end \ No newline at end of file