Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference buffer #290

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
74 changes: 42 additions & 32 deletions rhine-bayes/app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ import Text.Printf (printf)
-- transformers
import Control.Monad.Trans.Class

-- time
import Data.Time (addUTCTime, getCurrentTime)

-- mmorph
import Control.Monad.Morph

Expand Down Expand Up @@ -113,7 +110,7 @@ initialTemperature :: Temperature
initialTemperature = 7

-- | We assume the user changes the temperature randomly every 3 seconds.
temperatureProcess :: (MonadDistribution m, Diff td ~ Double) => BehaviourF m td () Temperature
temperatureProcess :: (MonadDistribution m, Diff td ~ Double) => Behaviour m td Temperature
temperatureProcess =
-- Draw events from a Poisson process with a rate of one event per 3 seconds
poissonHomogeneous 3
Expand Down Expand Up @@ -171,7 +168,7 @@ emptyResult =

-- | The number of particles used in the filter. Change according to available computing power.
nParticles :: Int
nParticles = 100
nParticles = 200

-- * Visualization

Expand All @@ -190,6 +187,7 @@ visualisation :: (Diff td ~ Double) => BehaviourF App td Result ()
visualisation = proc Result {temperature, measured, latent, particlesPosition, particlesTemperature} -> do
constMCl clearIO -< ()
time <- sinceInitS -< ()
dt <- sinceLastS -< ()
arrMCl paintIO
-<
toThermometer $
Expand All @@ -201,14 +199,15 @@ visualisation = proc Result {temperature, measured, latent, particlesPosition, p
[ printf "Temperature: %.2f" temperature
, printf "Particles: %i" $ length particlesPosition
, printf "Time: %.1f" time
, printf "FPS: %.1f" $ 1 / dt
]
return $ translate 0 ((-150) * n) $ text message
, color red $ rectangleUpperSolid thermometerWidth $ double2Float temperature * thermometerScale
]
drawBall -< (measured, 0.3, red)
drawBall -< (latent, 0.3, green)
drawParticles -< particlesPosition
drawParticlesTemperature -< particlesTemperature
drawParticles -< take 100 particlesPosition
drawParticlesTemperature -< take 100 particlesTemperature

-- ** Parameters for the temperature display

Expand Down Expand Up @@ -269,6 +268,7 @@ mains =
[ ("single rate", mainSingleRate)
, ("single rate, parameter collapse", mainSingleRateCollapse)
, ("multi rate, temperature process", mainMultiRate)
, ("multi rate, inference buffer", mainMultiRateInferenceBuffer)
]

main :: IO ()
Expand All @@ -279,15 +279,8 @@ main = do

-- ** Single-rate : One simulation step = one inference step = one display step

-- | Rescale to the 'Double' time domain
type GlossClock = RescaledClock GlossSimClockIO Double

glossClock :: GlossClock
glossClock =
RescaledClock
{ unscaledClock = GlossSimClockIO
, rescale = float2Double
}
glossClockSingleRate :: GlossClockUTC SamplerIO GlossSimClockIO
glossClockSingleRate = glossClockUTC GlossSimClockIO

-- *** Poor attempt at temperature inference: Particle collapse

Expand Down Expand Up @@ -328,11 +321,12 @@ mainClSFCollapse = proc () -> do
output <- filteredCollapse -< initialTemperature
visualisation -< output

mainSingleRateCollapse :: IO ()
mainSingleRateCollapse =
void $
sampleIO $
launchInGlossThread glossSettings $
reactimateCl glossClock mainClSFCollapse
reactimateCl glossClockSingleRate mainClSFCollapse

-- *** Infer temperature with a stochastic process

Expand All @@ -359,34 +353,23 @@ mainClSF = proc () -> do
output <- filtered -< initialTemperature
visualisation -< output

mainSingleRate :: IO ()
mainSingleRate =
void $
sampleIO $
launchInGlossThread glossSettings $
reactimateCl glossClock mainClSF
reactimateCl glossClockSingleRate mainClSF

-- ** Multi-rate: Simulation, inference, display at different rates

-- | Rescale the gloss clocks so they will be compatible with real 'UTCTime' (needed for compatibility with 'Millisecond')
type GlossClockUTC cl = RescaledClockS (GlossConcT IO) cl UTCTime (Tag cl)

glossClockUTC :: (Real (Time cl)) => cl -> GlossClockUTC cl
glossClockUTC cl =
RescaledClockS
{ unscaledClockS = cl
, rescaleS = const $ do
now <- liftIO getCurrentTime
return (arr $ \(timePassed, event) -> (addUTCTime (realToFrac timePassed) now, event), now)
}

{- | The part of the program which simulates latent position and sensor,
running 100 times a second.
-}
modelRhine :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT (Millisecond 100)) Temperature (Temperature, (Sensor, Pos))
modelRhine = hoistClSF sampleIOGloss (clId &&& genModelWithoutTemperature) @@ liftClock waitClock

-- | The user can change the temperature by pressing the up and down arrow keys.
userTemperature :: ClSF (GlossConcT IO) (GlossClockUTC GlossEventClockIO) () Temperature
userTemperature :: ClSF (GlossConcT IO) (GlossClockUTC IO GlossEventClockIO) () Temperature
userTemperature = tagS >>> arr (selector >>> fmap Product) >>> mappendS >>> arr (fmap getProduct >>> fromMaybe 1 >>> (* initialTemperature))
where
selector (EventKey (SpecialKey KeyUp) Down _ _) = Just 1.2
Expand All @@ -413,7 +396,7 @@ inference = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
}

-- | Visualize the current 'Result' at a rate controlled by the @gloss@ backend, usually 30 FPS.
visualisationRhine :: Rhine (GlossConcT IO) (GlossClockUTC GlossSimClockIO) Result ()
visualisationRhine :: Rhine (GlossConcT IO) (GlossClockUTC IO GlossSimClockIO) Result ()
visualisationRhine = hoistClSF sampleIOGloss visualisation @@ glossClockUTC GlossSimClockIO

{- FOURMOLU_DISABLE -}
Expand All @@ -435,6 +418,33 @@ mainMultiRate =
launchInGlossThread glossSettings $
flow mainRhineMultiRate

-- ** Multi-rate: Inference in separate buffer

mainRhineMultiRateInferenceBuffer =
userTemperature
@@ glossClockUTC GlossEventClockIO
>-- keepLast initialTemperature
--> modelRhine
@>>^ (\(temperature, (sensor, pos)) -> (sensor, (temperature, sensor, pos)))
>-- hoistResamplingBuffer sampleIOGloss (inferenceBuffer nParticles resampleSystematic (temperatureProcess >-> (prior &&& clId)) (\(pos, _) sensor -> sensorLikelihood pos sensor))
*-* keepLast (initialTemperature, zeroVector, zeroVector)
--> ( \(particles, (temperature, measured, latent)) ->
Result
{ temperature
, measured
, latent
, particlesPosition = second (const (1 / fromIntegral nParticles)) <$> particles
, particlesTemperature = (, 1 / fromIntegral nParticles) . snd <$> particles
}
)
^>>@ visualisationRhine

mainMultiRateInferenceBuffer :: IO ()
mainMultiRateInferenceBuffer =
void $
launchInGlossThread glossSettings $
flow mainRhineMultiRateInferenceBuffer

-- * Utilities

instance (MonadDistribution m) => MonadDistribution (GlossConcT m) where
Expand Down
52 changes: 18 additions & 34 deletions rhine-bayes/rhine-bayes.cabal
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
cabal-version: 2.2
name: rhine-bayes
version: 1.2
synopsis: monad-bayes backend for Rhine
description:
This package provides a backend to the @monad-bayes@ library,
enabling you to write stochastic processes as signal functions,
and performing online machine learning on them.
license: BSD3
license: BSD-3-Clause
license-file: LICENSE
author: Manuel Bärenz
maintainer: programming@manuelbaerenz.de
-- copyright:
category: FRP
build-type: Simple
extra-doc-files: README.md ChangeLog.md
cabal-version: 2.0

source-repository head
type: git
Expand All @@ -24,18 +24,14 @@ source-repository this
location: git@github.com:turion/rhine.git
tag: v1.1

library
exposed-modules:
FRP.Rhine.Bayes
other-modules:
Data.MonadicStreamFunction.Bayes
common opts
ghc-options: -Wall
build-depends: base >= 4.11 && < 4.18
, transformers >= 0.5
, rhine == 1.2
, dunai ^>= 0.11
, log-domain >= 0.12
, monad-bayes ^>= 1.2
hs-source-dirs: src
default-language: Haskell2010
default-extensions:
Arrows
Expand All @@ -45,44 +41,32 @@ library
FlexibleInstances
GeneralizedNewtypeDeriving
MultiParamTypeClasses
NamedFieldPuns
RankNTypes
ScopedTypeVariables
TupleSections
TypeFamilies
TypeOperators

ghc-options: -W
if flag(dev)
ghc-options: -Werror

library
import: opts
exposed-modules:
FRP.Rhine.Bayes
other-modules:
Data.MonadicStreamFunction.Bayes
hs-source-dirs: src

executable rhine-bayes-gloss
import: opts
main-is: Main.hs
hs-source-dirs: app
build-depends: base >= 4.11 && < 4.18
, rhine
, rhine-bayes
, rhine-gloss == 1.2
, dunai
, monad-bayes
, transformers
, log-domain
, mmorph
, time
default-language: Haskell2010
default-extensions:
Arrows
DataKinds
FlexibleContexts
NamedFieldPuns
RankNTypes
TupleSections
TypeApplications
TypeFamilies
TypeOperators

ghc-options: -W -threaded -rtsopts -with-rtsopts=-N
if flag(dev)
ghc-options: -Werror
ghc-options: -threaded -rtsopts -with-rtsopts=-N
build-depends: rhine-bayes
, rhine-gloss == 1.2
, mmorph

flag dev
description: Enable warnings as errors. Active on ci.
Expand Down
46 changes: 42 additions & 4 deletions rhine-bayes/src/FRP/Rhine/Bayes.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE ScopedTypeVariables #-}
module FRP.Rhine.Bayes where

-- transformers
Expand All @@ -7,7 +8,7 @@ import Control.Monad.Trans.Reader (ReaderT (..))
import Numeric.Log hiding (sum)

-- monad-bayes
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Class hiding (posterior)
import Control.Monad.Bayes.Population

-- dunai
Expand All @@ -18,6 +19,8 @@ import qualified Data.MonadicStreamFunction.Bayes as DunaiBayes

-- rhine
import FRP.Rhine
import Data.MonadicStreamFunction.Bayes (runPopulationS)
import GHC.Stack (HasCallStack)

-- * Inference methods

Expand Down Expand Up @@ -108,16 +111,16 @@ wienerVaryingLogDomain = wienerVarying >>> arr Exp
* The output is the number of events since the last tick.
-}
poissonInhomogeneous ::
(MonadDistribution m, Real (Diff td), Fractional (Diff td)) =>
(HasCallStack, MonadDistribution m, Real (Diff td), Fractional (Diff td)) =>
BehaviourF m td (Diff td) Int
poissonInhomogeneous = arrM $ \rate -> ReaderT $ \timeInfo -> poisson $ realToFrac $ sinceLast timeInfo / rate
poissonInhomogeneous = arrM $ \rate -> ReaderT $ \timeInfo -> poisson $ realToFrac $ max 0 (sinceLast timeInfo) / rate

-- | Like 'poissonInhomogeneous', but the rate is constant.
poissonHomogeneous ::
(MonadDistribution m, Real (Diff td), Fractional (Diff td)) =>
-- | The (constant) rate of the process
Diff td ->
BehaviourF m td () Int
Behaviour m td Int
poissonHomogeneous rate = arr (const rate) >>> poissonInhomogeneous

{- | The Gamma process, https://en.wikipedia.org/wiki/Gamma_process.
Expand All @@ -140,3 +143,38 @@ gammaInhomogeneous gamma = proc rate -> do
-}
bernoulliInhomogeneous :: (MonadDistribution m) => BehaviourF m td Double Bool
bernoulliInhomogeneous = arrMCl bernoulli


inferenceBuffer :: forall clA clS time m s a . (TimeDomain time, time ~ Time clS, time ~ Time clA, Monad m, MonadDistribution m)
=> Int ->
(forall n x . MonadDistribution n => PopulationT n x -> PopulationT n x) ->
Behaviour m time s -> (s -> a -> Log Double) -> ResamplingBuffer m clA clS a [s]
inferenceBuffer nParticles resampler process likelihood = msfBuffer' $ runPopulationS nParticles resampler posterior >>> arr (fmap fst)
where
processParClock :: ClSF m (ParallelClock clA clS) () s
processParClock = process
posterior :: Monad m => MSF (PopulationT m) (Either (TimeInfo clS) (TimeInfo clA, a)) s
posterior = proc tia -> do
lastTime <- iPre Nothing -< Just $ either absolute (absolute . fst) tia
let ti = (either (retag Right) (retag Left . fst) tia) { sinceLast = maybe (sinceInit ti) (absolute ti `diffTime`) lastTime }
s <- DunaiReader.runReaderS $ liftClSF processParClock -< (ti, ())
right $ arrM factor -< likelihood s . snd <$> tia
returnA -< s
-- inferenceBuffer nParticles process likelihood = go $ replicate nParticles process
-- where
-- stepToTime :: TimeInfo cl -> ClSF m cl () s -> m (s, ClSF m cl () s)
-- stepToTime ti clsf = second SomeBehaviour <$> runReaderT (unMSF (getSomeBehaviour clsf) ()) ti

-- -- Add resamplnig here
-- stepAllToTime :: Monad m => (forall n . MonadDistribution n => PopulationT n a -> PopulationT n a) -> TimeInfo cl -> [ClSF m cl () s] -> m [(s, ClSF m cl () s)]
-- stepAllToTime resampler ti = fmap _ . runPopulationT . resampler . fromWeightedList . fmap _ . mapM (stepToTime ti)

-- go :: [ClSF m (ParallelClock clA clB) () s] -> ResamplingBuffer m clA clS a s
-- go msfs = ResamplingBuffer
-- { put = \ti a -> do
-- stepped <- forM msfs $ \msf -> do
-- msf' <- stepToTime (retag Left ti) msf
-- _ -- factor each msf individually, resample all at the end
-- return $ go $ snd <$> stepped
-- , get = _
-- }
Loading
Loading