diff --git a/monad-bayes.cabal b/monad-bayes.cabal index 04b58dd1..7c473a78 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -162,6 +162,7 @@ test-suite monad-bayes-test HMM LDA LogReg + NonlinearSSM Sprinkler TestAdvanced TestBenchmarks @@ -173,6 +174,7 @@ test-suite monad-bayes-test TestPopulation TestSampler TestSequential + TestSSMFixtures TestStormerVerlet TestWeighted diff --git a/test/Spec.hs b/test/Spec.hs index baaa95d1..f5192155 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -13,6 +13,7 @@ import TestIntegrator qualified import TestPipes (hmms) import TestPipes qualified import TestPopulation qualified +import TestSSMFixtures qualified import TestSampler qualified import TestSequential qualified import TestStormerVerlet qualified @@ -167,3 +168,4 @@ main = hspec do passed7 `shouldBe` True TestBenchmarks.test + TestSSMFixtures.test diff --git a/test/TestSSMFixtures.hs b/test/TestSSMFixtures.hs new file mode 100644 index 00000000..86298a6d --- /dev/null +++ b/test/TestSSMFixtures.hs @@ -0,0 +1,80 @@ +module TestSSMFixtures where + +import Control.Monad.Bayes.Class (MonadDistribution) +import Control.Monad.Bayes.Inference.MCMC +import Control.Monad.Bayes.Inference.PMMH as PMMH (pmmh) +import Control.Monad.Bayes.Inference.RMSMC (rmsmc, rmsmcBasic, rmsmcDynamic) +import Control.Monad.Bayes.Inference.SMC +import Control.Monad.Bayes.Inference.SMC2 as SMC2 (smc2) +import Control.Monad.Bayes.Population +import Control.Monad.Bayes.Sampler.Strict (sampleIOfixed) +import Control.Monad.Bayes.Weighted (unweighted) +import NonlinearSSM +import System.IO (readFile') +import System.IO.Error (catchIOError, isDoesNotExistError) +import Test.Hspec + +data Alg = SMC | RMSMC | RMSMCDynamic | RMSMCBasic | PMMH | SMC2 + deriving (Show, Read, Eq, Ord, Enum, Bounded) + +algs :: [Alg] +algs = [minBound .. maxBound] + +fixtureToFilename :: Alg -> FilePath +fixtureToFilename alg = "test/fixtures/SSM-" ++ show alg ++ ".txt" + +type SSMData = [Double] + +t :: Int +t = 5 + +-- FIXME refactor such that it can be reused in ssm benchmark +runAlgFixed :: MonadDistribution m => SSMData -> Alg -> m String +runAlgFixed ys SMC = fmap show $ population $ smc SMCConfig {numSteps = t, numParticles = 10, resampler = resampleMultinomial} (param >>= model ys) +runAlgFixed ys RMSMC = + fmap show $ + population $ + rmsmc + MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH} + SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic} + (param >>= model ys) +runAlgFixed ys RMSMCBasic = + fmap show $ + population $ + rmsmcBasic + MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH} + SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic} + (param >>= model ys) +runAlgFixed ys RMSMCDynamic = + fmap show $ + population $ + rmsmcDynamic + MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH} + SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic} + (param >>= model ys) +runAlgFixed ys PMMH = + fmap show $ + unweighted $ + pmmh + MCMCConfig {numMCMCSteps = 2, numBurnIn = 0, proposal = SingleSiteMH} + SMCConfig {numSteps = t, numParticles = 3, resampler = resampleSystematic} + param + (model ys) +runAlgFixed ys SMC2 = fmap show $ population $ smc2 t 3 2 1 param (model ys) + +testFixture :: Alg -> SpecWith () +testFixture alg = do + let filename = fixtureToFilename alg + it ("should agree with the fixture " ++ filename) $ do + ys <- sampleIOfixed $ generateData t + fixture <- catchIOError (readFile' filename) $ \e -> + if isDoesNotExistError e + then return "" + else ioError e + sampled <- sampleIOfixed $ runAlgFixed (map fst ys) alg + -- Reset in case of fixture update or creation + writeFile filename sampled + fixture `shouldBe` sampled + +test :: SpecWith () +test = describe "TestSSMFixtures" $ mapM_ testFixture algs diff --git a/test/fixtures/SSM-PMMH.txt b/test/fixtures/SSM-PMMH.txt new file mode 100644 index 00000000..99c8097d --- /dev/null +++ b/test/fixtures/SSM-PMMH.txt @@ -0,0 +1 @@ +[[([74405.69500410178,143777.3515026691,195290.64675632896,483878.28639985673,600603.4104497777],1.0),([74405.69500410178,143777.3515026691,195290.64675632896,483878.28639985673,600603.4104497777],1.0),([74405.69500410178,143777.3515026691,195290.64675632896,483878.28639985673,600603.4104497777],1.0)],[([157620.04097610444,26661.523636563594,321204.4219216401,421274.0528523404,487363.3134055787],1.0),([157620.04097610444,26661.523636563594,321204.4219216401,421274.0528523404,487363.3134055787],1.0),([157620.04097610444,26661.523636563594,321204.4219216401,421274.0528523404,487363.3134055787],1.0)],[([-1.2600621067470811e67,-1.3171618074660135e67,3.55155213532486e66,-9.486041679240111e66,-1.4476178361450074e67],1.0),([-1.2600621067470811e67,-1.3171618074660135e67,3.55155213532486e66,-9.486041679240111e66,-1.4476178361450074e67],1.0),([-1.2600621067470811e67,-1.3171618074660135e67,3.55155213532486e66,-9.486041679240111e66,-1.4476178361450074e67],1.0)]] \ No newline at end of file diff --git a/test/fixtures/SSM-RMSMC.txt b/test/fixtures/SSM-RMSMC.txt new file mode 100644 index 00000000..f218833b --- /dev/null +++ b/test/fixtures/SSM-RMSMC.txt @@ -0,0 +1 @@ +[([-2.660097878548362e12,4.2406657192899927e11,5.124397279021509e11,-1.5388049692223555e12,5.100591467413694e11],1.7934178371940385e-141),([85360.89249769927,82937.71043376798,236528.18865034508,184344.32920611685,257818.8194711436],1.7934178371940385e-141),([111456.98861817457,175718.5311026478,-123025.40752936345,-216315.66615773254,-216313.3788230968],1.7934178371940385e-141),([19634.90868499715,3982.228157420627,71064.25154723959,-61863.47069790381,-124170.21806752698],1.7934178371940385e-141),([-1.5475119433556266e13,7.505784271326075e13,1.8087989772377637e13,2.3121907140178836e13,6.957846697896459e13],1.7934178371940385e-141),([3.3114425601587117e13,-7.174858488587559e12,-5.3895766067097984e13,-5.574198391310134e13,-7.085806871739197e13],1.7934178371940385e-141),([1.7797637411011798e9,2.257078068208236e9,1.7822875201019692e9,-3.1059378475534713e8,-2.022387835329388e8],1.7934178371940385e-141),([-362.4859209624216,597.209862141266,354.2417236032639,699.0287190356412,-719.9079388224256],1.7934178371940385e-141),([-6962602.171312882,4072501.6094401204,-1935864.3227236385,6378383.05623946,2107530.9972696295],1.7934178371940385e-141),([535.0932822165735,229.4681249747279,-318.4034398226934,-425.6731795563692,-736.158075205745],1.7934178371940385e-141)] \ No newline at end of file diff --git a/test/fixtures/SSM-RMSMCBasic.txt b/test/fixtures/SSM-RMSMCBasic.txt new file mode 100644 index 00000000..f218833b --- /dev/null +++ b/test/fixtures/SSM-RMSMCBasic.txt @@ -0,0 +1 @@ +[([-2.660097878548362e12,4.2406657192899927e11,5.124397279021509e11,-1.5388049692223555e12,5.100591467413694e11],1.7934178371940385e-141),([85360.89249769927,82937.71043376798,236528.18865034508,184344.32920611685,257818.8194711436],1.7934178371940385e-141),([111456.98861817457,175718.5311026478,-123025.40752936345,-216315.66615773254,-216313.3788230968],1.7934178371940385e-141),([19634.90868499715,3982.228157420627,71064.25154723959,-61863.47069790381,-124170.21806752698],1.7934178371940385e-141),([-1.5475119433556266e13,7.505784271326075e13,1.8087989772377637e13,2.3121907140178836e13,6.957846697896459e13],1.7934178371940385e-141),([3.3114425601587117e13,-7.174858488587559e12,-5.3895766067097984e13,-5.574198391310134e13,-7.085806871739197e13],1.7934178371940385e-141),([1.7797637411011798e9,2.257078068208236e9,1.7822875201019692e9,-3.1059378475534713e8,-2.022387835329388e8],1.7934178371940385e-141),([-362.4859209624216,597.209862141266,354.2417236032639,699.0287190356412,-719.9079388224256],1.7934178371940385e-141),([-6962602.171312882,4072501.6094401204,-1935864.3227236385,6378383.05623946,2107530.9972696295],1.7934178371940385e-141),([535.0932822165735,229.4681249747279,-318.4034398226934,-425.6731795563692,-736.158075205745],1.7934178371940385e-141)] \ No newline at end of file diff --git a/test/fixtures/SSM-RMSMCDynamic.txt b/test/fixtures/SSM-RMSMCDynamic.txt new file mode 100644 index 00000000..dd289ae5 --- /dev/null +++ b/test/fixtures/SSM-RMSMCDynamic.txt @@ -0,0 +1 @@ +[([61234.923743603955,79039.83817954235,354024.81636628765,225755.73057039993,-78843.37322818518],0.0),([61234.923743603955,-205024.66324964678,-438520.7645656072,-526045.6062936985,17959.08713788638],0.0),([61234.923743603955,130707.33129683959,260276.7204227042,538891.1815485102,432537.1717560617],0.0),([61234.923743603955,425968.16738967673,72802.89417099475,97062.29318414515,-90904.59187690681],0.0),([61234.923743603955,-80888.00179367141,122235.67304475381,48742.27626015559,-149682.32933231423],0.0),([61234.923743603955,43.833902800088254,-417728.0201965655,49565.634594935604,-303943.3354524304],0.0),([61234.923743603955,350501.69936972257,118986.06426751378,99950.78931739656,-60488.53431816819],0.0),([61234.923743603955,-117376.5868812376,116017.94360094423,378976.39475725644,74865.6296219704],0.0),([61234.923743603955,156368.9791422615,-586653.2615030725,-238480.82081038723,51581.15175237715],0.0),([61234.923743603955,-150776.59937461224,-30862.03908288705,200382.13919586508,-107135.36343350058],0.0)] \ No newline at end of file diff --git a/test/fixtures/SSM-SMC.txt b/test/fixtures/SSM-SMC.txt new file mode 100644 index 00000000..ea2840e4 --- /dev/null +++ b/test/fixtures/SSM-SMC.txt @@ -0,0 +1 @@ +[([-1.6946443595984358e8,-2.0398900541476977e8,5.988104418627801e8,5.186441087015647e7,-1.1107580460544899e9],3.747925572660412e-147),([-1.6946443595984358e8,1.762322765772586e8,1.3143034131110222e9,2.917359439754021e7,-4.678360689283452e8],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,-1.0039010467494124e9,-4.5194462919398534e8],3.747925572660412e-147),([-1.6946443595984358e8,-2.0398900541476977e8,6.65407483202134e8,-1.3610874802534976e9,1.7804869696064534e9],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,-1.0039010467494124e9,-4.5194462919398534e8],3.747925572660412e-147),([-1.6946443595984358e8,-7.848111477226721e8,-1.536250656089418e9,-1.2593852525318892e9,9.33478070563457e8],3.747925572660412e-147),([-1.6946443595984358e8,1.762322765772586e8,1.3143034131110222e9,2.917359439754021e7,-4.678360689283452e8],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,7.201669253635451e8,-6.528627637915363e8],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,7.201669253635451e8,-6.528627637915363e8],3.747925572660412e-147),([-1.6946443595984358e8,-7.848111477226721e8,-1.536250656089418e9,-1.2593852525318892e9,9.33478070563457e8],3.747925572660412e-147)] \ No newline at end of file diff --git a/test/fixtures/SSM-SMC2.txt b/test/fixtures/SSM-SMC2.txt new file mode 100644 index 00000000..2f49282b --- /dev/null +++ b/test/fixtures/SSM-SMC2.txt @@ -0,0 +1 @@ +[([([-9090.483553160731,-18364.240866577857,38447.317849110055,-3829.950678281628,-18689.938602553048],0.3333333333333341),([-9090.483553160731,-18364.240866577857,38447.317849110055,-3829.950678281628,-18689.938602553048],0.3333333333333341),([-9090.483553160731,-18364.240866577857,38447.317849110055,25131.836867847727,47603.03068211828],0.3333333333333341)],9.474658864966518e-180),([([-2.2418721864335723e9,2.4219211687208967e9,5.4463793547824545e9,7.074651672385337e8,2.595090695345872e8],0.3333333333333341),([-2.2418721864335723e9,2.4219211687208967e9,5.4463793547824545e9,7.074651672385337e8,2.595090695345872e8],0.3333333333333341),([-2.2418721864335723e9,2.4219211687208967e9,5.4463793547824545e9,7.074651672385337e8,2.595090695345872e8],0.3333333333333341)],9.474658864966518e-180)] \ No newline at end of file