Skip to content

Commit

Permalink
test release action
Browse files Browse the repository at this point in the history
  • Loading branch information
folivetti committed Sep 4, 2024
1 parent 561111c commit 10add93
Show file tree
Hide file tree
Showing 12 changed files with 226 additions and 60 deletions.
4 changes: 2 additions & 2 deletions apps/srtools/IO.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Data.List ( intercalate )
import Control.Monad ( unless, forM_ )
import System.Random ( StdGen )

import Data.SRTree ( SRTree (..), Fix (..), var, floatConstsToParam )
import Data.SRTree ( SRTree (..), Fix (..), var, floatConstsToParam, relabelVars )
import Algorithm.SRTree.Opt ( estimateSErr )
import Algorithm.SRTree.Likelihoods ( Distribution (..) )
import Algorithm.SRTree.ConfidenceIntervals ( printCI, BasicStats(_stdErr, _corr), CI )
Expand Down Expand Up @@ -138,7 +138,7 @@ printResultsScreen args seed dset varnames targt exprs = do
unless (allAreVars newvars) do
putStrLn "\nExpression and transformed features: "
putStr $ targt <> " ~ f(" <> intercalate ", " varnames' <> ") = "
putStrLn $ P.showExprWithVars varnames' transformedT
putStrLn $ P.showExprWithVars varnames' (relabelVars transformedT)
forM_ (zip varnames' newvars) \(vn, tv) -> do
putStrLn $ vn <> " = " <> P.showExprWithVars varnames tv

Expand Down
3 changes: 2 additions & 1 deletion apps/tinygp/GP.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Control.Monad (when)
import Data.Massiv.Array qualified as M
import Debug.Trace ( traceShow, trace )

data Method = Grow | Full
data Method = Grow | Full | BTC

type Rng a = StateT StdGen IO a
type GenUni = Fix SRTree -> Fix SRTree
Expand Down Expand Up @@ -158,6 +158,7 @@ isAbs _ = False
isInv (Fix (Bin Div (Fix (Const 1.0)) _)) = True
isInv _ = False
{-# INLINE isInv #-}

mutate :: HyperParams -> Individual -> Rng Individual
mutate hp ind = do
let sz = countNodes' (_tree ind)
Expand Down
50 changes: 50 additions & 0 deletions apps/tinygp/Initialization.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module Initialization where

data InitiMethod = GROW | FULL | BTC | HALFHALF
{-
btc = undefined
def btc(pset_, depth_, length_, type_=None):
if type_ is None:
type_ = pset_.ret
expr = []
arities = list(map(lambda x: x.arity, pset_.primitives[type_]))
minFunctionArity = min(arities)
maxFunctionArity = max(arities)
# adapt length to restrictions of the primitive set
if length_ % 2 == 0 and minFunctionArity > 1:
length_ = length_ + 1 if np.random.random_sample(1) > 0.5 else length_ - 1
targetLength = length_ - 1 # don't count the root node
maxFunctionArity = min(maxFunctionArity, targetLength)
minFunctionArity = min(minFunctionArity, targetLength)
root = sampleChild(pset_, minFunctionArity, maxFunctionArity, type_)
# inner lists of the form [node, depth, childIndex]
# childIndex is only used at the end to transform
# the representation from breadth to prefix
expr.append([root, 0, 1])
openSlots = root.arity
for i in range(0, length_):
(node, nodeDepth, childIndex) = expr[i]
childDepth = nodeDepth + 1
for j in range(0, getArity(node)):
maxArity = 0 if childDepth == depth_ - 1 else min(maxFunctionArity, targetLength - openSlots)
minArity = min(minFunctionArity, maxArity)
child = sampleChild(pset_, minArity, maxArity, type_)
if j == 0:
expr[i][2] = len(expr)
expr.append([child, childDepth, 0])
openSlots += getArity(child)
nodes = breadthToPrefix(expr)
return nodes
-}
2 changes: 2 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ dependencies:

library:
source-dirs: src
ghc-options:
- -fwarn-incomplete-patterns

executables:
tinygp:
Expand Down
42 changes: 35 additions & 7 deletions src/Algorithm/SRTree/AD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ applyBin op (Left ly) (Left ry) =
Mul -> ly !*! ry
Div -> ly !/! ry
Power -> ly .** ry
PowerAbs -> M.map abs (ly .** ry)
AQ -> ly !/! (M.map sqrt (M.map (+1) (ry !*! ry)))

applyBin op (Left ly) (Right ry) =
Left $ unsafeLiftArray (\ x -> evalOp op x ry) ly
Expand Down Expand Up @@ -145,12 +147,14 @@ forwardMode xss theta err tree = let (yhat, jacob) = runST $ cataM lToR alg tree
forM_ [0 .. p-1] $ \j -> do
vl <- UMA.unsafeRead tl (i :. j)
vr <- UMA.unsafeRead tr (i :. j)
case op of
Add -> UMA.unsafeWrite tl (i :. j) (vl+vr)
Sub -> UMA.unsafeWrite tl (i :. j) (vl-vr)
Mul -> UMA.unsafeWrite tl (i :. j) (vl * ri + vr * li)
Div -> UMA.unsafeWrite tl (i :. j) ((vl * ri - vr * li) / ri^2)
Power -> UMA.unsafeWrite tl (i :. j) (li ** (ri - 1) * (ri * vl + li * log li * vr))
UMA.unsafeWrite tl (i :. j) $ case op of
Add -> (vl+vr)
Sub -> (vl-vr)
Mul -> (vl * ri + vr * li)
Div -> ((vl * ri - vr * li) / ri^2)
Power -> (li ** (ri - 1) * (ri * vl + li * log li * vr))
PowerAbs -> (abs li ** ri) * (vr * log (abs li) + ri * vl / li)
AQ -> ((1 + ri*ri) * vl - li * ri * vr) / (1 + ri*ri) ** 1.5
tlF <- UMA.unsafeFreeze cmp tl
pure (applyBin op l r, tlF)

Expand Down Expand Up @@ -183,7 +187,14 @@ forwardModeUnique xss theta err = second (toGrad . DL.toList) . cata alg
alg (Bin Power (v1, l) (v2, r)) = let dv1 = v1 ** (v2 - one)
dv2 = v1 * log v1
in (v1 ** v2, DL.map (*dv1) (DL.append (DL.map (*v2) l) (DL.map (*dv2) r)))

alg (Bin PowerAbs (v1, l) (v2, r)) = let dv1 = abs v1 ** v2
dv2 = DL.map (* (log (abs v1))) r
dv3 = DL.map (*(v2 / v1)) l
in (abs v1 ** v2, DL.map (*dv1) (DL.append dv2 dv3))
alg (Bin AQ (v1, l) (v2, r)) = let dv1 = DL.map (*(1 + v2*v2)) l
dv2 = DL.map (*(-v1*v2)) r
in (v1/sqrt(1 + v2*v2), DL.map (/(1 + v2*v2)**1.5) $ DL.append dv1 dv2)

data TupleF a b = Single a | T a b | Branch a b b deriving Functor -- hi, I'm a tree
type Tuple a = Fix (TupleF a)

Expand Down Expand Up @@ -248,13 +259,23 @@ reverseModeUnique xss theta ys f t = unsafePerformIO $
-- dx is the current derivative so far
-- fx is the evaluation of the left branch
-- gx is the evaluation of the right branch
--
-- this should return a tuple, where the left element is
-- dx * d op(f(x), g(x)) / d f(x) and
-- the right branch dx * d op (f(x), g(x)) / d g(x)
diff Add dx fx gy = (dx, dx)
diff Sub dx fx gy = (dx, negate' dx)
diff Mul dx fx gy = (applyBin Mul dx gy, applyBin Mul dx fx)
diff Div dx fx gy = (applyBin Div dx gy, applyBin Mul dx (applyBin Div (negate' fx) (applyBin Mul gy gy)))
diff Power dx fx gy = let dxl = applyBin Mul dx (applyBin Power fx (applyBin Sub gy (Right 1)))
dv2 = applyBin Mul fx (applyUni Log fx)
in (applyBin Mul dxl gy, applyBin Mul dxl dv2)
diff PowerAbs dx fx gy = let dxl = applyBin Mul (applyBin Mul gy fx) (applyBin PowerAbs fx (applyBin Sub gy (Right 2)))
dxr = applyBin Mul (applyUni LogAbs fx) (applyBin PowerAbs fx gy)
in (applyBin Mul dxl dx, applyBin Mul dxr dx)
diff AQ dx fx gy = let dxl = applyUni Recip (applyUni Sqrt (applyBin Add (applyUni Square gy) (Right 1)))
dxy = applyBin Div (applyBin Mul fx gy) (applyUni Cube (applyUni Sqrt (applyBin Add (applyUni Square gy) (Right 1))))
in (applyBin Mul dxl dx, applyBin Mul dxy dx)


-- once we reach a leaf with a parameter, we return a singleton
Expand Down Expand Up @@ -293,3 +314,10 @@ forwardModeUniqueJac xss theta = snd . second (map (M.computeAs M.S) . DL.toList
alg (Bin Power (v1, l) (v2, r)) = let dv1 = v1 ** (v2 - one)
dv2 = v1 * log v1
in (v1 ** v2, DL.map (*dv1) (DL.append (DL.map (*v2) l) (DL.map (*dv2) r)))
alg (Bin PowerAbs (v1, l) (v2, r)) = let dv1 = abs v1 ** v2
dv2 = DL.map (* (log (abs v1))) r
dv3 = DL.map (*(v2 / v1)) l
in (abs v1 ** v2, DL.map (*dv1) (DL.append dv2 dv3))
alg (Bin AQ (v1, l) (v2, r)) = let dv1 = DL.map (*(1 + v2*v2)) l
dv2 = DL.map (*(-v1*v2)) r
in (v1/sqrt(1 + v2*v2), DL.map (/(1 + v2*v2)**1.5) $ DL.append dv1 dv2)
2 changes: 2 additions & 0 deletions src/Algorithm/SRTree/ModelSelection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ treeToNat = cata $
opToNat Mul = 1.720356134912558
opToNat Div = 2.60436883851265
opToNat Power = 2.527957363394847
opToNat PowerAbs = 2.527957363394847
opToNat AQ = 2.60436883851265

funToNat :: Function -> Double
funToNat Sqrt = 4.780867285331753
Expand Down
2 changes: 2 additions & 0 deletions src/Data/SRTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module Data.SRTree
, numberOfVars
, getIntConsts
, relabelParams
, relabelVars
, constsToParam
, floatConstsToParam
, paramsToConst
Expand All @@ -53,6 +54,7 @@ import Data.SRTree.Internal
, numberOfVars
, getIntConsts
, relabelParams
, relabelVars
, constsToParam
, floatConstsToParam
, paramsToConst
Expand Down
28 changes: 19 additions & 9 deletions src/Data/SRTree/Derivative.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ import Data.Attoparsec.ByteString.Char8 (double)
deriveBy :: Bool -> Int -> Fix SRTree -> Fix SRTree
deriveBy p dx = fst (mutu alg1 alg2)
where
alg1 (Var ix) = if not p && ix == dx then 1 else 0
alg1 (Param ix) = if p && ix == dx then 1 else 0
alg1 (Const _) = 0
alg1 (Uni f t) = derivative f (snd t) * fst t
alg1 (Bin Add l r) = fst l + fst r
alg1 (Bin Sub l r) = fst l - fst r
alg1 (Bin Mul l r) = fst l * snd r + snd l * fst r
alg1 (Bin Div l r) = (fst l * snd r - snd l * fst r) / snd r ** 2
alg1 (Bin Power l r) = snd l ** (snd r - 1) * (snd r * fst l + snd l * log (snd l) * fst r)
alg1 (Var ix) = if not p && ix == dx then 1 else 0
alg1 (Param ix) = if p && ix == dx then 1 else 0
alg1 (Const _) = 0
alg1 (Uni f t) = derivative f (snd t) * fst t
alg1 (Bin Add l r) = fst l + fst r
alg1 (Bin Sub l r) = fst l - fst r
alg1 (Bin Mul l r) = fst l * snd r + snd l * fst r
alg1 (Bin Div l r) = (fst l * snd r - snd l * fst r) / snd r ** 2
alg1 (Bin Power l r) = snd l ** (snd r - 1) * (snd r * fst l + snd l * log (snd l) * fst r)
alg1 (Bin PowerAbs l r) = (abs (snd l) ** (snd r)) * (fst r * log (abs (snd l)) + snd r * fst l / snd l)
alg1 (Bin AQ l r) = ((1 + snd r * snd r) * fst l - snd l * snd r * fst r) / (1 + snd r * snd r) ** 1.5

alg2 (Var ix) = var ix
alg2 (Param ix) = param ix
Expand Down Expand Up @@ -72,10 +74,14 @@ derivative ASinh = recip . sqrt . (1+) . (^2)
derivative ACosh = \x -> 1 / (sqrt (x-1) * sqrt (x+1))
derivative ATanh = recip . (1-) . (^2)
derivative Sqrt = recip . (2*) . sqrt
derivative SqrtAbs = \x -> x / (2.0 * abs x ** (3.0/2.0))
derivative Cbrt = recip . (3*) . (**(1/3)) . (^2)
derivative Square = (2*)
derivative Exp = exp
derivative Log = recip
derivative LogAbs = recip
derivative Recip = negate . recip . (^2)
derivative Cube = (3*) . (^2)
{-# INLINE derivative #-}

-- | Second-order derivative of supported functions
Expand All @@ -98,10 +104,14 @@ doubleDerivative ASinh = \x -> x / (x^2 + 1)**(3/2) -- check
doubleDerivative ACosh = \x -> 1 / (sqrt (x-1) * sqrt (x+1)) -- check
doubleDerivative ATanh = recip . (1-) . (^2) -- check
doubleDerivative Sqrt = \x -> -1 / (4 * sqrt x^3)
doubleDerivative SqrtAbs = \x -> (-x)*x/(4 * abs x ** (3.5))
doubleDerivative Cbrt = \x -> -2 / (9 * x * (x^2)**(1/3))
doubleDerivative Square = const 2
doubleDerivative Exp = exp
doubleDerivative Log = negate . recip . (^2)
doubleDerivative LogAbs = negate . recip . (^2)
doubleDerivative Recip = (*2) . recip . (^3)
doubleDerivative Cube = (6*)
{-# INLINE doubleDerivative #-}

-- | Symbolic derivative by a variable
Expand Down
17 changes: 16 additions & 1 deletion src/Data/SRTree/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ evalOp Sub = (-)
evalOp Mul = (*)
evalOp Div = (/)
evalOp Power = (**)
evalOp PowerAbs = \l r -> abs l ** r
evalOp AQ = \l r -> l / sqrt(1 + r*r)
{-# INLINE evalOp #-}

-- evaluates a function
Expand All @@ -114,10 +116,14 @@ evalFun ASinh = asinh
evalFun ACosh = acosh
evalFun ATanh = atanh
evalFun Sqrt = sqrt
evalFun SqrtAbs = sqrt . abs
evalFun Cbrt = cbrt
evalFun Square = (^2)
evalFun Log = log
evalFun LogAbs = log . abs
evalFun Exp = exp
evalFun Recip = recip
evalFun Cube = (^3)
{-# INLINE evalFun #-}

-- Cubic root
Expand Down Expand Up @@ -145,6 +151,7 @@ inverseFunc Square = Sqrt
-- inverseFunc Cbrt = (^3)
inverseFunc Log = Exp
inverseFunc Exp = Log
inverseFunc Recip = Recip
-- inverseFunc Abs = Abs -- we assume abs(x) = sqrt(x^2) so y = sqrt(x^2) => x^2 = y^2 => x = sqrt(y^2) = x = abs(y)
inverseFunc x = error $ show x ++ " has no support for inverse function"
{-# INLINE inverseFunc #-}
Expand All @@ -165,11 +172,15 @@ evalInverse ASinh = sinh
evalInverse ACosh = cosh
evalInverse ATanh = tanh
evalInverse Sqrt = (^2)
evalInverse SqrtAbs = (^2)
evalInverse Square = sqrt
evalInverse Cbrt = (^3)
evalInverse Log = exp
evalInverse LogAbs = exp
evalInverse Exp = log
evalInverse Abs = abs -- we assume abs(x) = sqrt(x^2) so y = sqrt(x^2) => x^2 = y^2 => x = sqrt(y^2) = x = abs(y)
evalInverse Recip = recip
evalInverse Cube = cbrt

-- | evals the right inverse of an operator
invright :: Floating a => Op -> a -> (a -> a)
Expand All @@ -178,6 +189,8 @@ invright Sub v = (+v)
invright Mul v = (/v)
invright Div v = (*v)
invright Power v = (**(1/v))
invright PowerAbs v = (**(1/v))
invright AQ v = (* sqrt (1 + v*v))

-- | evals the left inverse of an operator
invleft :: Floating a => Op -> a -> (a -> a)
Expand All @@ -186,7 +199,9 @@ invleft Sub v = (+v) . negate -- y = v - r => r = v - y
invleft Mul v = (/v)
invleft Div v = (v/) -- y = v / r => r = v/y
invleft Power v = logBase v -- (/(log v)) . log -- y = v ^ r log y = r log v r = log y / log v
invleft PowerAbs v = logBase v . abs
invleft AQ v = (v/)

-- | List of invertible functions
invertibles :: [Function]
invertibles = [Id, Sin, Cos, Tan, Tanh, ASin, ACos, ATan, ATanh, Sqrt, Square, Log, Exp]
invertibles = [Id, Sin, Cos, Tan, Tanh, ASin, ACos, ATan, ATanh, Sqrt, Square, Log, Exp, Recip]
33 changes: 32 additions & 1 deletion src/Data/SRTree/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ module Data.SRTree.Internal
, numberOfVars
, getIntConsts
, relabelParams
, relabelVars
, constsToParam
, floatConstsToParam
, paramsToConst
Expand All @@ -59,7 +60,7 @@ data SRTree val =
deriving (Show, Eq, Ord, Functor)

-- | Supported operators
data Op = Add | Sub | Mul | Div | Power
data Op = Add | Sub | Mul | Div | Power | PowerAbs | AQ
deriving (Show, Read, Eq, Ord, Enum)

-- | Supported functions
Expand All @@ -79,10 +80,14 @@ data Function =
| ACosh
| ATanh
| Sqrt
| SqrtAbs
| Cbrt
| Square
| Log
| LogAbs
| Exp
| Recip
| Cube
deriving (Show, Read, Eq, Ord, Enum)

-- | create a tree with a single node representing a variable
Expand Down Expand Up @@ -153,6 +158,9 @@ instance Fractional (Fix SRTree) where
l / r = Fix $ Bin Div l r
{-# INLINE (/) #-}

recip = Fix . Uni Recip
{-# INLINE recip #-}

fromRational = Fix . Const . fromRational
{-# INLINE fromRational #-}

Expand Down Expand Up @@ -359,6 +367,29 @@ relabelParams t = cataM leftToRight alg t `evalState` 0
alg (Uni f t) = pure $ Fix (Uni f t)
alg (Bin f l r) = pure $ Fix (Bin f l r)

-- | Relabel the parameters indices incrementaly starting from 0
--
-- >>> showExpr . relabelParams $ "x0" + "t0" * sin ("t1" + "x1") - "t0"
-- "x0" + "t0" * sin ("t1" + "x1") - "t2"
relabelVars :: Fix SRTree -> Fix SRTree
relabelVars t = cataM leftToRight alg t `evalState` 0
where
-- | leftToRight (left to right) defines the sequence of processing
leftToRight (Uni f mt) = Uni f <$> mt;
leftToRight (Bin f ml mr) = Bin f <$> ml <*> mr
leftToRight (Var ix) = pure (Var ix)
leftToRight (Param ix) = pure (Param ix)
leftToRight (Const c) = pure (Const c)

-- | any time we reach a Param ix, it replaces ix with current state
-- and increments one to the state.
alg :: SRTree (Fix SRTree) -> State Int (Fix SRTree)
alg (Var ix) = do iy <- get; modify (+1); pure (var iy)
alg (Param ix) = pure $ param ix
alg (Const c) = pure $ Fix $ Const c
alg (Uni f t) = pure $ Fix (Uni f t)
alg (Bin f l r) = pure $ Fix (Bin f l r)

-- | Change constant values to a parameter, returning the changed tree and a list
-- of parameter values
--
Expand Down
Loading

0 comments on commit 10add93

Please sign in to comment.