From 76355e649fcf70c485c40471f92a7d5669aa25cd Mon Sep 17 00:00:00 2001
From: johnmaxrin <johnmaxrin@gmail.com>
Date: Mon, 1 Apr 2024 22:53:14 +0530
Subject: [PATCH] Add Task Monitor

Co-authored-by: Julian Samaroo <jpsamaroo@gmail.com>
---
 Project.toml        | 25 +++++++++++++-----------
 src/Dagger.jl       |  1 +
 src/queue.jl        |  4 ++--
 src/task_monitor.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++
 src/thunk.jl        |  2 +-
 5 files changed, 65 insertions(+), 14 deletions(-)
 create mode 100644 src/task_monitor.jl

diff --git a/Project.toml b/Project.toml
index f59fb278c..e768514f3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,6 +12,7 @@ MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
 OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
@@ -20,9 +21,21 @@ SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
 TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
+[weakdeps]
+Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+
+[extensions]
+GraphVizExt = "GraphViz"
+GraphVizSimpleExt = "Colors"
+PlotsExt = ["DataFrames", "Plots"]
+
 [compat]
 DataStructures = "0.18"
 Graphs = "1"
@@ -34,22 +47,12 @@ Requires = "1"
 ScopedValues = "1.1"
 Statistics = "1"
 StatsBase = "0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
+TaskLocalValues = "0.1"
 TimespanLogging = "0.1"
 julia = "1.8"
 
-[extensions]
-GraphVizSimpleExt = "Colors"
-GraphVizExt = "GraphViz"
-PlotsExt = ["DataFrames", "Plots"]
-
 [extras]
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-
-[weakdeps]
-Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
-DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
-Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
diff --git a/src/Dagger.jl b/src/Dagger.jl
index 4dc9174c2..d726062a7 100644
--- a/src/Dagger.jl
+++ b/src/Dagger.jl
@@ -41,6 +41,7 @@ include("scopes.jl")
 include("utils/scopes.jl")
 include("eager_thunk.jl")
 include("queue.jl")
+include("task_monitor.jl")
 include("thunk.jl")
 include("submission.jl")
 include("chunks.jl")
diff --git a/src/queue.jl b/src/queue.jl
index 8d7cc7f6e..761c5a454 100644
--- a/src/queue.jl
+++ b/src/queue.jl
@@ -15,9 +15,9 @@ enqueue!(::EagerTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}}) =
     eager_launch!(specs)
 
 enqueue!(spec::Pair{EagerTaskSpec,EagerThunk}) =
-    enqueue!(get_options(:task_queue, EagerTaskQueue()), spec)
+    enqueue!(get_options(:task_queue, MONITOR_QUEUE[]), spec)
 enqueue!(specs::Vector{Pair{EagerTaskSpec,EagerThunk}}) =
-    enqueue!(get_options(:task_queue, EagerTaskQueue()), specs)
+    enqueue!(get_options(:task_queue, MONITOR_QUEUE[]), specs)
 
 struct LazyTaskQueue <: AbstractTaskQueue
     tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}
diff --git a/src/task_monitor.jl b/src/task_monitor.jl
new file mode 100644
index 000000000..8aa485481
--- /dev/null
+++ b/src/task_monitor.jl
@@ -0,0 +1,47 @@
+using ProgressMeter
+using TaskLocalValues
+
+struct MonitorTaskQueue <: AbstractTaskQueue
+    running_tasks::Vector{WeakRef}
+    MonitorTaskQueue() = new(WeakRef[])
+end
+function enqueue!(queue::MonitorTaskQueue, spec::Pair{EagerTaskSpec,EagerThunk})
+    push!(queue.running_tasks, WeakRef(spec[2]))
+    upper = get_options(:task_queue, EagerTaskQueue())
+    enqueue!(upper, spec)
+end
+
+function enqueue!(queue::MonitorTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})
+    for (_, task) in specs
+        push!(queue.running_tasks, WeakRef(task))
+    end
+    upper = get_options(:task_queue, EagerTaskQueue())
+    enqueue!(upper, specs)
+end
+
+const MONITOR_QUEUE = TaskLocalValue{MonitorTaskQueue}(MonitorTaskQueue)
+
+"Monitors and displays the progress of any still-executing tasks."
+function monitor()
+    queue = MONITOR_QUEUE[]
+    running_tasks = queue.running_tasks
+    isempty(running_tasks) && return
+
+    ntasks = length(running_tasks)
+    meter = Progress(ntasks;
+                     desc="Waiting for $ntasks tasks...",
+                     dt=0.01, showspeed=true)
+    while !isempty(running_tasks)
+        for (i, task_weak) in reverse(collect(enumerate(running_tasks)))
+            task = task_weak.value
+            if task === nothing || isready(task)
+                next!(meter)
+                deleteat!(running_tasks, i)
+            end
+        end
+        sleep(0.01)
+    end
+    finish!(meter)
+
+    return
+end
diff --git a/src/thunk.jl b/src/thunk.jl
index bf92744ab..3c342da54 100644
--- a/src/thunk.jl
+++ b/src/thunk.jl
@@ -309,7 +309,7 @@ function spawn(f, args...; kwargs...)
     args_kwargs = args_kwargs_to_pairs(args, kwargs)
 
     # Get task queue, and don't let it propagate
-    task_queue = get_options(:task_queue, EagerTaskQueue())
+    task_queue = get_options(:task_queue, MONITOR_QUEUE[])
     options = NamedTuple(filter(opt->opt[1] != :task_queue, Base.pairs(options)))
     propagates = filter(prop->prop != :task_queue, propagates)
     options = merge(options, (;propagates))