Skip to content

Commit

Permalink
Improve parallel template (#2809)
Browse files Browse the repository at this point in the history
- Closes #2806 

Now we properly wait for the log and worker threads to finish instead of
keeping them alive until the main thread dies.

Also, because we are now using
[`replicateConcurrently_`](https://hackage.haskell.org/package/effectful-2.3.0.0/docs/Effectful-Concurrent-Async.html#v:replicateConcurrently)
from `async`, any exception in a worker thread should be properly
propagated.
  • Loading branch information
janmasrovira authored Jun 7, 2024
1 parent a4f5515 commit a2c1a4a
Showing 1 changed file with 91 additions and 86 deletions.
177 changes: 91 additions & 86 deletions src/Parallel/ParallelTemplate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ module Parallel.ParallelTemplate
compileArgsNumWorkers,
compileArgsCompileNode,
compileArgsPreProcess,
compilationError,
compile,
)
where

import Control.Concurrent (ThreadId)
import Control.Concurrent.STM.TVar (stateTVar)
import Control.Exception qualified as GHC
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Effectful.Concurrent
Expand All @@ -49,7 +47,6 @@ data CompilationState nodeId compiledProof = CompilationState
_compilationPending :: HashMap nodeId (HashSet nodeId),
_compilationStartedNum :: Natural,
_compilationFinishedNum :: Natural,
_compilationError :: Maybe JuvixError,
_compilationTotalNum :: Natural
}

Expand All @@ -67,14 +64,31 @@ newtype CompileQueue nodeId = CompileQueue
{ _compileQueue :: TBQueue nodeId
}

data LogQueueItem
= LogQueueItem LogItem
| -- | no more log items will be handled after this
LogQueueClose

newtype Logs = Logs
{ _logQueue :: TQueue LogItem
{ _logQueue :: TQueue LogQueueItem
}

newtype NodesIndex nodeId node = NodesIndex
{ _nodesIndex :: HashMap nodeId node
}

data Task nodeId node = Task
{ _taskNum :: Natural,
_taskTotal :: Natural,
_taskNodeId :: nodeId,
_taskNode :: node
}

data Finished
= -- | All modules have started compilation. They might still be compiling
FinishedNoPending
| FinishedPending

makeLenses ''Logs
makeLenses ''NodesIndex
makeLenses ''CompileQueue
Expand All @@ -89,16 +103,11 @@ instance (Show nodeId, Pretty nodeId) => Pretty (Dependencies nodeId) where
| (from, deps) <- HashMap.toList (d ^. dependenciesTable)
]

data Finished
= FinishedOk
| FinishedError JuvixError
| FinishedNot

compilationStateFinished :: CompilationState nodeId compileProof -> Finished
compilationStateFinished CompilationState {..}
| Just err <- _compilationError = FinishedError err
| _compilationFinishedNum == _compilationTotalNum = FinishedOk
| otherwise = FinishedNot
| _compilationStartedNum == _compilationTotalNum = FinishedNoPending
| _compilationStartedNum > _compilationTotalNum = impossible
| otherwise = FinishedPending

addCompiledModule ::
forall nodeId proof.
Expand Down Expand Up @@ -156,7 +165,6 @@ compile args@CompileArgs {..} = do
{ _compilationStartedNum = 0,
_compilationFinishedNum = 0,
_compilationTotalNum = numMods,
_compilationError = Nothing,
_compilationPending = deps ^. dependenciesTable,
_compilationState = mempty
}
Expand All @@ -169,58 +177,75 @@ compile args@CompileArgs {..} = do
. runReader deps
. crashOnError
$ do
let newThread ::
forall r' a.
(Members '[Concurrent] r') =>
Sem r' a ->
Sem r' ()
newThread m = void . forkFinally m $ \case
Left err -> GHC.throw err
Right {} -> return ()
withAsync handleLogs $ \_logHandler -> do
let useAsync = False
if
| useAsync ->
replicateConcurrently_ _compileArgsNumWorkers $
lookForWork @nodeId @node @compileProof
| otherwise ->
replicateM_ _compileArgsNumWorkers
. newThread
$ lookForWork @nodeId @node @compileProof
waitForWorkers @nodeId @compileProof
withAsync handleLogs $ \logHandler -> do
replicateConcurrently_ _compileArgsNumWorkers $
lookForWork @nodeId @node @compileProof
wait logHandler
(^. compilationState) <$> readTVarIO varCompilationState

handleLogs :: (Members '[ProgressLog, Concurrent, Reader Logs] r) => Sem r ()
handleLogs = do
x <- asks (^. logQueue) >>= atomically . readTQueue
progressLog x
handleLogs
case x of
LogQueueClose -> return ()
LogQueueItem l -> do
progressLog l
handleLogs

waitForWorkers ::
forall nodeId compileProof r.
( Members
getTask ::
forall nodeId (node :: GHCType) compileProof (s :: [Effect]) r.
( Hashable nodeId,
Members
'[ Concurrent,
Reader (TVar (CompilationState nodeId compileProof)),
Error JuvixError,
Reader (CompileArgs s nodeId node compileProof),
Reader (NodesIndex nodeId node),
Reader (CompileQueue nodeId),
Reader Logs
]
r
) =>
Sem r ()
waitForWorkers = do
Logs logs <- ask
Sem r (Maybe (Task nodeId node))
getTask = do
stVar <- ask @(TVar (CompilationState nodeId compileProof))
qq <- asks (^. compileQueue)
cstVar <- ask @(TVar (CompilationState nodeId compileProof))
(finished, noMoreLogs) <- atomically $ do
idx <- ask @(NodesIndex nodeId node)
logs <- ask
args <- ask @(CompileArgs s nodeId node compileProof)
tid <- myThreadId
atomically $ do
finished <- compilationStateFinished <$> readTVar cstVar
noMoreLogs <- isEmptyTQueue logs
return (finished, noMoreLogs)
let waitMore = waitForWorkers @nodeId @compileProof
case finished of
FinishedError err
| noMoreLogs -> throw err
| otherwise -> waitMore
FinishedNot -> waitMore
FinishedOk -> unless noMoreLogs waitMore
case finished of
FinishedNoPending -> return Nothing
FinishedPending -> do
nextModuleId :: nodeId <- readTBQueue qq
let n :: node =
run
. runReader idx
$ getNode nextModuleId
compSt <- readTVar stVar
modifyTVar stVar (over compilationStartedNum succ)
let num = succ (compSt ^. compilationStartedNum)
total = compSt ^. compilationTotalNum
name = annotate (AnnKind KNameTopModule) (pretty ((args ^. compileArgsNodeName) n))
progress :: Doc CodeAnn =
kwBracketL
<> annotate AnnLiteralInteger (pretty num)
<+> kwOf
<+> annotate AnnLiteralInteger (pretty total) <> kwBracketR <> " "
kwCompiling = annotate AnnKeyword "Compiling"
isLast = num == total
logMsg tid logs (progress <> kwCompiling <> " " <> name)
when isLast (logClose logs)
return $
Just
Task
{ _taskNum = num,
_taskTotal = total,
_taskNodeId = nextModuleId,
_taskNode = n
}

lookForWork ::
forall nodeId node compileProof (s :: [Effect]) r.
Expand All @@ -241,30 +266,9 @@ lookForWork ::
) =>
Sem r ()
lookForWork = do
qq <- asks (^. compileQueue)
stVar <- ask @(TVar (CompilationState nodeId compileProof))
logs <- ask
args <- ask @(CompileArgs s nodeId node compileProof)
idx <- ask @(NodesIndex nodeId node)
tid <- myThreadId
nextModule <- atomically $ do
nextModule :: nodeId <- readTBQueue qq
let n :: node = run . runReader idx $ getNode nextModule
name = annotate (AnnKind KNameTopModule) (pretty ((args ^. compileArgsNodeName) n))
compSt <- readTVar stVar
modifyTVar stVar (over compilationStartedNum succ)
let num = compSt ^. compilationStartedNum
total = compSt ^. compilationTotalNum
progress :: Doc CodeAnn =
kwBracketL
<> annotate AnnLiteralInteger (pretty (succ num))
<+> kwOf
<+> annotate AnnLiteralInteger (pretty total) <> kwBracketR <> " "
kwCompiling = annotate AnnKeyword "Compiling"
logMsg tid logs (progress <> kwCompiling <> " " <> name)
return nextModule
compileNode @s @nodeId @node @compileProof nextModule
lookForWork @nodeId @node @compileProof @s @r
whenJustM (getTask @nodeId @node @compileProof @s) $ \Task {..} -> do
compileNode @s @nodeId @node @compileProof _taskNodeId
lookForWork @nodeId @node @compileProof @s @r

getNode ::
forall nodeId node r.
Expand Down Expand Up @@ -295,13 +299,8 @@ compileNode ::
compileNode nodId = do
m :: node <- getNode nodId
compileFun <- asks @(CompileArgs s nodeId node compileProof) (^. compileArgsCompileNode)
st :: TVar (CompilationState nodeId compileProof) <- ask
result :: Either (CallStack, JuvixError) compileProof <-
inject $
tryError @JuvixError (compileFun m)
case result of
Left (_, err) -> atomically (modifyTVar st (set compilationError (Just err)))
Right proof -> registerCompiledModule @nodeId @node @s @compileProof nodId proof
proof :: compileProof <- inject (compileFun m)
registerCompiledModule @nodeId @node @s @compileProof nodId proof

registerCompiledModule ::
forall nodeId node s compileProof r.
Expand All @@ -328,11 +327,17 @@ registerCompiledModule m proof = do
toQueue <- stateTVar mutSt (swap . addCompiledModule deps m proof)
forM_ toQueue (writeTBQueue qq)

logClose :: Logs -> STM ()
logClose (Logs q) = do
STM.writeTQueue q LogQueueClose

logMsg :: ThreadId -> Logs -> Doc CodeAnn -> STM ()
logMsg tid (Logs q) msg = do
STM.writeTQueue
q
LogItem
{ _logItemMessage = msg,
_logItemThreadId = tid
}
( LogQueueItem
LogItem
{ _logItemMessage = msg,
_logItemThreadId = tid
}
)

0 comments on commit a2c1a4a

Please sign in to comment.