From a2c1a4aea3f47497756ad9de6acf0cda720e9eda Mon Sep 17 00:00:00 2001 From: Jan Mas Rovira Date: Fri, 7 Jun 2024 11:24:40 +0200 Subject: [PATCH] Improve parallel template (#2809) - Closes #2806 Now we properly wait for the log and worker threads to finish instead of keeping them alive until the main thread dies. Also, because we are now using [`replicateConcurrently_`](https://hackage.haskell.org/package/effectful-2.3.0.0/docs/Effectful-Concurrent-Async.html#v:replicateConcurrently) from `async`, any exception in a worker thread should be properly propagated. --- src/Parallel/ParallelTemplate.hs | 177 ++++++++++++++++--------------- 1 file changed, 91 insertions(+), 86 deletions(-) diff --git a/src/Parallel/ParallelTemplate.hs b/src/Parallel/ParallelTemplate.hs index 9d0f6b33c4..0f11985aee 100644 --- a/src/Parallel/ParallelTemplate.hs +++ b/src/Parallel/ParallelTemplate.hs @@ -15,14 +15,12 @@ module Parallel.ParallelTemplate compileArgsNumWorkers, compileArgsCompileNode, compileArgsPreProcess, - compilationError, compile, ) where import Control.Concurrent (ThreadId) import Control.Concurrent.STM.TVar (stateTVar) -import Control.Exception qualified as GHC import Data.HashMap.Strict qualified as HashMap import Data.HashSet qualified as HashSet import Effectful.Concurrent @@ -49,7 +47,6 @@ data CompilationState nodeId compiledProof = CompilationState _compilationPending :: HashMap nodeId (HashSet nodeId), _compilationStartedNum :: Natural, _compilationFinishedNum :: Natural, - _compilationError :: Maybe JuvixError, _compilationTotalNum :: Natural } @@ -67,14 +64,31 @@ newtype CompileQueue nodeId = CompileQueue { _compileQueue :: TBQueue nodeId } +data LogQueueItem + = LogQueueItem LogItem + | -- | no more log items will be handled after this + LogQueueClose + newtype Logs = Logs - { _logQueue :: TQueue LogItem + { _logQueue :: TQueue LogQueueItem } newtype NodesIndex nodeId node = NodesIndex { _nodesIndex :: HashMap nodeId node } +data Task nodeId node = Task + { _taskNum :: Natural, + _taskTotal :: Natural, + _taskNodeId :: nodeId, + _taskNode :: node + } + +data Finished + = -- | All modules have started compilation. They might still be compiling + FinishedNoPending + | FinishedPending + makeLenses ''Logs makeLenses ''NodesIndex makeLenses ''CompileQueue @@ -89,16 +103,11 @@ instance (Show nodeId, Pretty nodeId) => Pretty (Dependencies nodeId) where | (from, deps) <- HashMap.toList (d ^. dependenciesTable) ] -data Finished - = FinishedOk - | FinishedError JuvixError - | FinishedNot - compilationStateFinished :: CompilationState nodeId compileProof -> Finished compilationStateFinished CompilationState {..} - | Just err <- _compilationError = FinishedError err - | _compilationFinishedNum == _compilationTotalNum = FinishedOk - | otherwise = FinishedNot + | _compilationStartedNum == _compilationTotalNum = FinishedNoPending + | _compilationStartedNum > _compilationTotalNum = impossible + | otherwise = FinishedPending addCompiledModule :: forall nodeId proof. @@ -156,7 +165,6 @@ compile args@CompileArgs {..} = do { _compilationStartedNum = 0, _compilationFinishedNum = 0, _compilationTotalNum = numMods, - _compilationError = Nothing, _compilationPending = deps ^. dependenciesTable, _compilationState = mempty } @@ -169,58 +177,75 @@ compile args@CompileArgs {..} = do . runReader deps . crashOnError $ do - let newThread :: - forall r' a. - (Members '[Concurrent] r') => - Sem r' a -> - Sem r' () - newThread m = void . forkFinally m $ \case - Left err -> GHC.throw err - Right {} -> return () - withAsync handleLogs $ \_logHandler -> do - let useAsync = False - if - | useAsync -> - replicateConcurrently_ _compileArgsNumWorkers $ - lookForWork @nodeId @node @compileProof - | otherwise -> - replicateM_ _compileArgsNumWorkers - . newThread - $ lookForWork @nodeId @node @compileProof - waitForWorkers @nodeId @compileProof + withAsync handleLogs $ \logHandler -> do + replicateConcurrently_ _compileArgsNumWorkers $ + lookForWork @nodeId @node @compileProof + wait logHandler (^. compilationState) <$> readTVarIO varCompilationState handleLogs :: (Members '[ProgressLog, Concurrent, Reader Logs] r) => Sem r () handleLogs = do x <- asks (^. logQueue) >>= atomically . readTQueue - progressLog x - handleLogs + case x of + LogQueueClose -> return () + LogQueueItem l -> do + progressLog l + handleLogs -waitForWorkers :: - forall nodeId compileProof r. - ( Members +getTask :: + forall nodeId (node :: GHCType) compileProof (s :: [Effect]) r. + ( Hashable nodeId, + Members '[ Concurrent, Reader (TVar (CompilationState nodeId compileProof)), - Error JuvixError, + Reader (CompileArgs s nodeId node compileProof), + Reader (NodesIndex nodeId node), + Reader (CompileQueue nodeId), Reader Logs ] r ) => - Sem r () -waitForWorkers = do - Logs logs <- ask + Sem r (Maybe (Task nodeId node)) +getTask = do + stVar <- ask @(TVar (CompilationState nodeId compileProof)) + qq <- asks (^. compileQueue) cstVar <- ask @(TVar (CompilationState nodeId compileProof)) - (finished, noMoreLogs) <- atomically $ do + idx <- ask @(NodesIndex nodeId node) + logs <- ask + args <- ask @(CompileArgs s nodeId node compileProof) + tid <- myThreadId + atomically $ do finished <- compilationStateFinished <$> readTVar cstVar - noMoreLogs <- isEmptyTQueue logs - return (finished, noMoreLogs) - let waitMore = waitForWorkers @nodeId @compileProof - case finished of - FinishedError err - | noMoreLogs -> throw err - | otherwise -> waitMore - FinishedNot -> waitMore - FinishedOk -> unless noMoreLogs waitMore + case finished of + FinishedNoPending -> return Nothing + FinishedPending -> do + nextModuleId :: nodeId <- readTBQueue qq + let n :: node = + run + . runReader idx + $ getNode nextModuleId + compSt <- readTVar stVar + modifyTVar stVar (over compilationStartedNum succ) + let num = succ (compSt ^. compilationStartedNum) + total = compSt ^. compilationTotalNum + name = annotate (AnnKind KNameTopModule) (pretty ((args ^. compileArgsNodeName) n)) + progress :: Doc CodeAnn = + kwBracketL + <> annotate AnnLiteralInteger (pretty num) + <+> kwOf + <+> annotate AnnLiteralInteger (pretty total) <> kwBracketR <> " " + kwCompiling = annotate AnnKeyword "Compiling" + isLast = num == total + logMsg tid logs (progress <> kwCompiling <> " " <> name) + when isLast (logClose logs) + return $ + Just + Task + { _taskNum = num, + _taskTotal = total, + _taskNodeId = nextModuleId, + _taskNode = n + } lookForWork :: forall nodeId node compileProof (s :: [Effect]) r. @@ -241,30 +266,9 @@ lookForWork :: ) => Sem r () lookForWork = do - qq <- asks (^. compileQueue) - stVar <- ask @(TVar (CompilationState nodeId compileProof)) - logs <- ask - args <- ask @(CompileArgs s nodeId node compileProof) - idx <- ask @(NodesIndex nodeId node) - tid <- myThreadId - nextModule <- atomically $ do - nextModule :: nodeId <- readTBQueue qq - let n :: node = run . runReader idx $ getNode nextModule - name = annotate (AnnKind KNameTopModule) (pretty ((args ^. compileArgsNodeName) n)) - compSt <- readTVar stVar - modifyTVar stVar (over compilationStartedNum succ) - let num = compSt ^. compilationStartedNum - total = compSt ^. compilationTotalNum - progress :: Doc CodeAnn = - kwBracketL - <> annotate AnnLiteralInteger (pretty (succ num)) - <+> kwOf - <+> annotate AnnLiteralInteger (pretty total) <> kwBracketR <> " " - kwCompiling = annotate AnnKeyword "Compiling" - logMsg tid logs (progress <> kwCompiling <> " " <> name) - return nextModule - compileNode @s @nodeId @node @compileProof nextModule - lookForWork @nodeId @node @compileProof @s @r + whenJustM (getTask @nodeId @node @compileProof @s) $ \Task {..} -> do + compileNode @s @nodeId @node @compileProof _taskNodeId + lookForWork @nodeId @node @compileProof @s @r getNode :: forall nodeId node r. @@ -295,13 +299,8 @@ compileNode :: compileNode nodId = do m :: node <- getNode nodId compileFun <- asks @(CompileArgs s nodeId node compileProof) (^. compileArgsCompileNode) - st :: TVar (CompilationState nodeId compileProof) <- ask - result :: Either (CallStack, JuvixError) compileProof <- - inject $ - tryError @JuvixError (compileFun m) - case result of - Left (_, err) -> atomically (modifyTVar st (set compilationError (Just err))) - Right proof -> registerCompiledModule @nodeId @node @s @compileProof nodId proof + proof :: compileProof <- inject (compileFun m) + registerCompiledModule @nodeId @node @s @compileProof nodId proof registerCompiledModule :: forall nodeId node s compileProof r. @@ -328,11 +327,17 @@ registerCompiledModule m proof = do toQueue <- stateTVar mutSt (swap . addCompiledModule deps m proof) forM_ toQueue (writeTBQueue qq) +logClose :: Logs -> STM () +logClose (Logs q) = do + STM.writeTQueue q LogQueueClose + logMsg :: ThreadId -> Logs -> Doc CodeAnn -> STM () logMsg tid (Logs q) msg = do STM.writeTQueue q - LogItem - { _logItemMessage = msg, - _logItemThreadId = tid - } + ( LogQueueItem + LogItem + { _logItemMessage = msg, + _logItemThreadId = tid + } + )