Skip to content

Commit

Permalink
small tweak in full report
Browse files Browse the repository at this point in the history
  • Loading branch information
folivetti committed Sep 2, 2024
1 parent 473b512 commit 0b62439
Show file tree
Hide file tree
Showing 31 changed files with 220 additions and 145 deletions.
65 changes: 54 additions & 11 deletions apps/srtools/IO.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# language BlockArguments #-}
{-# language LambdaCase #-}
module IO where

import System.IO ( hClose, hPutStrLn, openFile, stderr, stdout, IOMode(WriteMode), Handle )
Expand All @@ -7,14 +8,15 @@ import Data.List ( intercalate )
import Control.Monad ( unless, forM_ )
import System.Random ( StdGen )

import Data.SRTree ( SRTree, Fix (..), floatConstsToParam )
import Data.SRTree ( SRTree (..), Fix (..), var, floatConstsToParam )
import Algorithm.SRTree.Opt ( estimateSErr )
import Algorithm.SRTree.Likelihoods ( Distribution (..) )
import Algorithm.SRTree.ConfidenceIntervals ( printCI, BasicStats(_stdErr, _corr), CI )
import qualified Data.SRTree.Print as P

import Args ( Args(outfile, alpha,msErr,dist,niter) )
import Report
import Data.SRTree.Recursion ( cata )

import Debug.Trace ( trace, traceShow )

Expand Down Expand Up @@ -54,23 +56,23 @@ processTree args seed dset tree ix = (basic, sseOrig, sseOpt, info, cis)
cis = getCI args' dset basic (alpha args')

-- print the results to a csv format (except CI)
printResults :: Args -> StdGen -> Datasets -> [Either String (Fix SRTree)] -> IO ()
printResults args seed dset exprs = do
printResults :: Args -> StdGen -> Datasets -> [String] -> [Either String (Fix SRTree)] -> IO ()
printResults args seed dset varnames exprs = do
hStat <- openWriteWithDefault stdout (outfile args)
hPutStrLn hStat csvHeader
forM_ (zip [0..] exprs)
\(ix, tree) ->
case tree of
Left err -> hPutStrLn stderr ("invalid expression: " <> err)
Right t -> let treeData = processTree args seed dset t ix
in hPutStrLn hStat (toCsv treeData)
in hPutStrLn hStat (toCsv treeData varnames)
unless (null (outfile args)) (hClose hStat)

-- change the stats into a string
toCsv :: (BasicInfo, SSE, SSE, Info, e) -> String
toCsv (basic, sseOrig, sseOpt, info, _) = intercalate "," (sBasic <> sSSEOrig <> sSSEOpt <> sInfo)
toCsv :: (BasicInfo, SSE, SSE, Info, e) -> [String] -> String
toCsv (basic, sseOrig, sseOpt, info, _) varnames = intercalate "," (sBasic <> sSSEOrig <> sSSEOpt <> sInfo)
where
sBasic = [ show (_index basic), show (_fname basic), P.showExpr (_expr basic)
sBasic = [ show (_index basic), show (_fname basic), P.showExprWithVars varnames (_expr basic)
, show (_nNodes basic), show (_nParams basic)
, intercalate ";" (map show (_params basic))
]
Expand All @@ -80,9 +82,39 @@ toCsv (basic, sseOrig, sseOpt, info, _) = intercalate "," (sBasic <> sSSEOrig <>
<> [intercalate ";" (map show (_fisher info))]
showF p f = show (f p)

-- get trees of transformed features
getTransformedFeatures :: Fix SRTree -> (Fix SRTree, [Fix SRTree])
getTransformedFeatures = cata $
\case
Var ix -> (Fix $ Var ix, [])
Param ix -> (Fix $ Param ix, [])
Const x -> (Fix $ Const x, [])
Uni f (t, vars) -> (Fix $ Uni f t, vars)
Bin op (l, vs1) (r, vs2) -> case (hasNoParam l, hasNoParam r) of
(False, True) -> let vs = vs1 <> vs2
in (Fix $ Bin op l (var $ length vs), vs <> [r])
(True, False) -> let vs = vs1 <> vs2
in (Fix $ Bin op (var $ length vs) r, vs <> [l])
( _, _) -> (Fix $ Bin op l r, vs1 <> vs2) -- vs1 == vs2 == []

where
hasNoParam = cata $
\case
Var ix -> True
Param ix -> False
Const x -> if floor x == ceiling x then True else False
Uni f t -> t
Bin op l r -> l && r

allAreVars :: [Fix SRTree] -> Bool
allAreVars = all isOnlyVar
where
isOnlyVar (Fix (Var _)) = True
isOnlyVar _ = False

-- print the information on screen (including CIs)
printResultsScreen :: Args -> StdGen -> Datasets -> [Either String (Fix SRTree)] -> IO ()
printResultsScreen args seed dset exprs = do
printResultsScreen :: Args -> StdGen -> Datasets -> [String] -> String -> [Either String (Fix SRTree)] -> IO ()
printResultsScreen args seed dset varnames targt exprs = do
forM_ (zip [0..] exprs)
\(ix, tree) ->
case tree of
Expand All @@ -95,9 +127,20 @@ printResultsScreen args seed dset exprs = do
sdecim n = show . decim n
nplaces = 4


printToScreen ix (basic, _, sseOpt, info, (sts, cis, pis_tr, pis_val, pis_te)) =
do putStrLn $ "=================== EXPR " <> show ix <> " =================="
putStrLn $ P.showExpr (_expr basic)
do let (transformedT, newvars) = getTransformedFeatures (_expr basic)
varnames' = ['z': show ix | ix <- [0 .. length newvars - 1]]
putStrLn $ "=================== EXPR " <> show ix <> " =================="
putStr $ targt <> " ~ f(" <> intercalate ", " varnames <> ") = "
putStrLn $ P.showExprWithVars varnames (_expr basic)

unless (allAreVars newvars) do
putStrLn "\nExpression and transformed features: "
putStr $ targt <> " ~ f(" <> intercalate ", " varnames' <> ") = "
putStrLn $ P.showExprWithVars varnames' transformedT
forM_ (zip varnames' newvars) \(vn, tv) -> do
putStrLn $ vn <> " = " <> P.showExprWithVars varnames tv

putStrLn "\n---------General stats:---------\n"
putStrLn $ "Number of nodes: " <> show (_nNodes basic)
Expand Down
10 changes: 6 additions & 4 deletions apps/srtools/Main.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module Main (main) where

import Data.ByteString.Char8 ( pack, unpack, split )
import Options.Applicative
import Text.ParseSR.IO ( withInput )
import System.Random ( getStdGen, mkStdGen )
import Text.ParseSR.IO ( withInput )

import Args
import IO
Expand All @@ -12,14 +13,15 @@ main :: IO ()
main = do
args <- execParser opts
g <- getStdGen
(dset, varnames) <- getDataset args
(dset, varnames, tgname) <- getDataset args
let seed = if rseed args < 0
then g
else mkStdGen (rseed args)
varnames' = map unpack $ split ',' $ pack varnames
withInput (infile args) (from args) varnames False (simpl args)
>>= if toScreen args
then printResultsScreen args seed dset -- full report on screne
else printResults args seed dset -- csv file
then printResultsScreen args seed dset varnames' tgname -- full report on screne
else printResults args seed dset varnames' -- csv file
where
opts = info (opt <**> helper)
( fullDesc <> progDesc "Optimize the parameters of\
Expand Down
8 changes: 4 additions & 4 deletions apps/srtools/Report.hs
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,18 @@ data Info = Info { _bic :: Double
}

-- load the datasets
getDataset :: Args -> IO (Datasets, String)
getDataset :: Args -> IO (Datasets, String, String)
getDataset args = do
((xTr, yTr, xVal, yVal), varnames) <- loadDataset (dataset args) (hasHeader args)
((xTr, yTr, xVal, yVal), varnames, tgname) <- loadDataset (dataset args) (hasHeader args)
let (A.Sz m) = A.size yVal
let (mXVal, mYVal) = if m == 0
then (Nothing, Nothing)
else (Just xVal, Just yVal)
(mXTe, mYTe) <- if null (test args)
then pure (Nothing, Nothing)
else do ((xTe, yTe, _, _), _) <- loadDataset (test args) (hasHeader args)
else do ((xTe, yTe, _, _), _, _) <- loadDataset (test args) (hasHeader args)
pure (Just xTe, Just yTe)
pure (DS xTr yTr mXVal mYVal mXTe mYTe, varnames)
pure (DS xTr yTr mXVal mYVal mXTe mYTe, varnames, tgname)

getBasicStats :: Args -> StdGen -> Datasets -> Fix SRTree -> Int -> BasicInfo
getBasicStats args seed dset tree ix
Expand Down
2 changes: 1 addition & 1 deletion apps/tinygp/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ main :: IO ()
main = do
args <- execParser opts
g <- getStdGen
((x, y, _, _), _) <- loadDataset (dataset args) True
((x, y, _, _), _, _) <- loadDataset (dataset args) True
let hp = HP 2 4 25 (popSize args) 2 (pc args) (pm args) terms nonterms
(Sz2 _ nFeats) = size x
terms = [var ix | ix <- [0 .. nFeats-1]] <> [param ix | ix <- [0 .. 5]]
Expand Down
6 changes: 3 additions & 3 deletions docs/Algorithm-EqSat-Egraph.html

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions docs/Algorithm-EqSat-EqSatDB.html

Large diffs are not rendered by default.

Loading

0 comments on commit 0b62439

Please sign in to comment.