Skip to content

Commit

Permalink
eqgraphGP working but slow
Browse files Browse the repository at this point in the history
  • Loading branch information
folivetti committed Sep 9, 2024
1 parent ba7c3fb commit ae8b7a1
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 20 deletions.
16 changes: 10 additions & 6 deletions apps/egraphGP/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import Options.Applicative as Opt hiding (Const)
import Random
import System.Random

import Debug.Trace

-- Insert random expression
-- Evaluate random subtree
-- Insert new random parent eNode
Expand Down Expand Up @@ -90,7 +92,9 @@ opDoesNotExistWith node ecId = Prelude.any (not . (`sameOpAs` node) . snd) . _pa
egraphGP :: SRMatrix -> PVector -> [Fix SRTree] -> Int -> RndEGraph (Fix SRTree, Double)
egraphGP x y terms nIter = do
insertRndExpr
forM_ [1 .. nIter] $ \i -> gpStep >> when (i `mod` 1000 == 0) (getBestExpr >>= (io . print . snd)) -- applyMergeOnlyDftl myCost >>
forM_ [1 .. nIter] $ \i -> gpStep
>> when (i `mod` 1000 == 0) (getBestExpr >>= (io . print . snd))
>> when (i `mod` 5000 == 0) (applyMergeOnlyDftl myCost)
getBestExpr
where
rndTerm = Random.randomFrom terms
Expand Down Expand Up @@ -160,19 +164,19 @@ egraphGP x y terms nIter = do
else pure Nothing
when (isJust meId) do
let eId = fromJust meId
curFit <- gets (_fitness . _info . (IM.! eId) . _eClass)
eId' <- canonical eId
curFit <- gets (_fitness . _info . (IM.! eId') . _eClass)
when (isNothing curFit) do
t <- getBest eId
t <- getBest eId'
f <- fitnessFun x y t
updateFitness f eId
updateFitness f eId'
-- io $ print ('p', showExpr t, f)

gpStep :: RndEGraph ()
gpStep = do choice <- rnd $ randomFrom [1,2,3]
gpStep = do choice <- rnd $ randomFrom [1,2,2,3,3,3]
if | choice == 1 -> insertRndExpr
| choice == 2 -> insertRndParent
| otherwise -> evalRndSubTree
-- applyMergeOnlyDftl myCost
rebuild myCost

data Args = Args
Expand Down
4 changes: 3 additions & 1 deletion src/Algorithm/EqSat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import Data.Set (Set)
import qualified Data.Set as Set
import Control.Monad ( zipWithM )

import Debug.Trace

-- | The `Scheduler` stores a map with the banned iterations of a certain rule .
-- TODO: make it more customizable.
type Scheduler a = State (IntMap Int) a
Expand Down Expand Up @@ -135,7 +137,7 @@ runEqSat costFun rules maxIter = go maxIter IntMap.empty
applySingleMergeOnlyEqSat :: Monad m => CostFun -> [Rule] -> EGraphST m ()
applySingleMergeOnlyEqSat costFun rules =
do db <- createDB
let matchSch = matchWithScheduler db 10
let matchSch = matchWithScheduler db 0
matchAll = zipWithM matchSch [0..]
(matches, sch') = runState (matchAll rules') IntMap.empty
mapM_ (uncurry (applyMergeOnlyMatch costFun)) $ concat matches
Expand Down
22 changes: 13 additions & 9 deletions src/Algorithm/EqSat/Egraph.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
module Algorithm.EqSat.Egraph where

import Control.Lens (element, makeLenses, over, (&), (+~), (-~), (.~), (^.))
import Control.Monad (forM, forM_, when, foldM)
import Control.Monad (forM, forM_, when, foldM, void)
import Control.Monad.State
import Control.Monad.Identity
import Data.AEq (AEq ((~==)))
Expand All @@ -32,6 +32,7 @@ import Data.Set (Set)
import qualified Data.Set as Set
import System.Random (Random (randomR), StdGen)

import Debug.Trace

type EClassId = Int -- DO NOT CHANGE THIS without changing the next line! This will break the use of IntMap for speed
type ClassIdMap = IntMap
Expand Down Expand Up @@ -130,15 +131,14 @@ rebuild costFun =
-- e-graph, merge the e-classes
repair :: Monad m => CostFun -> EClassId -> ENode -> EGraphST m ()
repair costFun ecId enode =
do modify' $ over eNodeToEClass (Map.delete enode)
do -- modify' $ over eNodeToEClass (Map.delete enode)
enode' <- canonize enode
ecId' <- canonical ecId
doExist <- gets (Map.member enode' . _eNodeToEClass)
if doExist
then do ecIdCanon <- gets ((Map.! enode') . _eNodeToEClass)
_ <- merge costFun ecIdCanon ecId'
pure ()
else modify' $ over eNodeToEClass (Map.insert enode' ecId')
doExist <- gets ((Map.!? enode') . _eNodeToEClass)
case doExist of
Just ecIdCanon -> void $ merge costFun ecIdCanon ecId'
Nothing -> modify' $ over eNodeToEClass (Map.insert enode' ecId')
modify' $ over eNodeToEClass (Map.delete enode)

-- | repair the analysis of the e-class
-- considering the new added e-node
Expand Down Expand Up @@ -217,8 +217,12 @@ modifyEClass costFun ecId =
-- join data from two e-classes
joinData :: EClassData -> EClassData -> EClassData
joinData (EData c1 b1 cn1 fit1 sz1) (EData c2 b2 cn2 fit2 sz2) =
EData (min c1 c2) b (combineConsts cn1 cn2) (min fit1 fit2) (min sz1 sz2)
EData (min c1 c2) b (combineConsts cn1 cn2) (minMaybe fit1 fit2) (min sz1 sz2)
where
minMaybe Nothing x = x
minMaybe x Nothing = x
minMaybe x y = min x y

b = if c1 <= c2 then b1 else b2
combineConsts (ConstVal x) (ConstVal y)
| abs (x-y) < 1e-7 = ConstVal $ (x+y)/2
Expand Down
6 changes: 2 additions & 4 deletions src/Algorithm/EqSat/EqSatDB.hs
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ canonizeMap (subst, cv) = (,cv) . Map.fromList <$> traverse f (Map.toList subst)

applyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMatch costFun rule match' =
do eg <- get
let conds = getConditions rule
do let conds = getConditions rule
match <- canonizeMap match'
validHeight <- isValidHeight match
validConds <- mapM (`isValidConditions` match) conds
Expand All @@ -214,8 +213,7 @@ applyMatch costFun rule match' =

applyMergeOnlyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMergeOnlyMatch costFun rule match' =
do eg <- get
let conds = getConditions rule
do let conds = getConditions rule
match <- canonizeMap match'
validHeight <- isValidHeight match
validConds <- mapM (`isValidConditions` match) conds
Expand Down
1 change: 1 addition & 0 deletions src/Algorithm/EqSat/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ rewritesFun =
, abs ("x" * "y") :=> abs "x" * abs "y" :| isConstPt "x"
, sqrt ("z" * ("x" - "y")) :=> sqrt (negate "z") * sqrt ("y" - "x")
, sqrt ("z" * ("x" + "y")) :=> sqrt "z" * sqrt ("x" + "y")
, recip (recip "x") :=> "x" :| isNotZero "x"
]

-- Rules that reduces redundant parameters
Expand Down

0 comments on commit ae8b7a1

Please sign in to comment.