foldTree does not optimize well #946

Open
meooow25 opened this issue Apr 30, 2023 · 2 comments

@meooow25
Contributor

We have in Data.Tree

foldTree :: (a -> [b] -> b) -> Tree a -> b 

Unfortunately, foldTree does not optimize as well as it could when b is a function.

As an example, consider that we want to calculate the sum of depths of nodes in a tree.
We can write a recursive function manually:

depthSum_rec :: Tree a -> Int
depthSum_rec t = go t 0 0 where
    go (Node _ ts) depth acc = foldl' (\acc' t' -> go t' (depth+1) acc') (acc + depth) ts

--    depthSum_rec: OK (0.18s)
--      5.18 ms ± 508 μs

Now let's use foldTree:

depthSum_foldTree :: Tree a -> Int
depthSum_foldTree t = foldTree f t 0 0 where
    f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks

--    depthSum_foldTree: OK (0.34s)
--      43.6 ms ± 3.2 ms

That's a lot worse! The problem is that the list of partially applied functions, [b], is actually built in memory; see GHC#23319. According to SPJ, this can't easily be improved.


Consider a different fold function that also folds over the children's results directly, so the [b] list is never created:

foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go where go (Node x ts) = f x (foldr (c . go) z ts)

Now we can write:

depthSum_foldTree2 :: Tree a -> Int
depthSum_foldTree2 t = foldTree2 f f' (const id) t 0 0 where
    f _ k depth acc = k depth (acc + depth)
    f' k1 k2 depth acc = k2 depth (k1 (depth+1) acc)

--    depthSum_foldTree2: OK (0.23s)
--      5.16 ms ± 376 μs

As good as depthSum_rec! Could we have foldTree2 (perhaps with a better name) in Data.Tree?
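
As a quick sanity check, separate from the benchmark, the sketch below runs the three definitions on one small tree to confirm they compute the same sum. The tree `small` is a hypothetical example that does not appear in the original measurements; foldTree2 is the proposed function and is not exported by Data.Tree.

```haskell
import Data.List (foldl')
import Data.Tree (Tree (..), foldTree)

-- The proposed fold (not part of Data.Tree).
foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go where go (Node x ts) = f x (foldr (c . go) z ts)

depthSum_rec :: Tree a -> Int
depthSum_rec t = go t 0 0 where
    go (Node _ ts) depth acc = foldl' (\acc' t' -> go t' (depth+1) acc') (acc + depth) ts

depthSum_foldTree :: Tree a -> Int
depthSum_foldTree t = foldTree f t 0 0 where
    f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks

depthSum_foldTree2 :: Tree a -> Int
depthSum_foldTree2 t = foldTree2 f f' (const id) t 0 0 where
    f _ k depth acc = k depth (acc + depth)
    f' k1 k2 depth acc = k2 depth (k1 (depth+1) acc)

-- Hypothetical test tree: 'a' at depth 0, 'b' and 'c' at depth 1, 'd' at depth 2.
small :: Tree Char
small = Node 'a' [Node 'b' [Node 'd' []], Node 'c' []]

main :: IO ()
main = print (map ($ small) [depthSum_rec, depthSum_foldTree, depthSum_foldTree2])
-- prints [4,4,4], since 0 + 1 + 2 + 1 = 4
```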


The benchmark setup, for completeness:
import Data.List
import Data.Tree
import Test.Tasty.Bench

main :: IO ()
main = defaultMain
    [ env (pure binTree) $ \t -> bgroup ""
        [ bench "depthSum_rec" $ whnf depthSum_rec t
        , bench "depthSum_foldTree" $ whnf depthSum_foldTree t
        , bench "depthSum_foldTree2" $ whnf depthSum_foldTree2 t
        ]
    ]

binTree :: Tree Int
binTree = unfoldTree (\x -> (x, takeWhile (<1000000) [2*x + 1, 2*x + 2])) 1
@treeowl
Contributor

treeowl commented Apr 30, 2023

The type of the function gives me no clue what it does. That makes me a bit suspicious. Can you write documentation that makes it easy for people to think about? What benefit does this have over doing the folding by hand?

@meooow25
Contributor Author

> The type of the function gives me no clue what it does.

It is just the replacement of all the constructors involved in a Tree. So foldTree2 Node (:) [] = id.
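
To make the "replace the constructors" reading concrete, here is a small sketch that checks the identity law on one tree and recovers the existing foldTree by rebuilding the child list with (:) and []. The names `foldTreeVia2` and `sample` are hypothetical and used only for illustration.

```haskell
import Data.Tree (Tree (..), foldTree)

-- The proposed fold (not part of Data.Tree).
foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go where go (Node x ts) = f x (foldr (c . go) z ts)

-- foldTree expressed through foldTree2: the list of child results is
-- rebuilt with (:) and [], then handed to f as usual.
foldTreeVia2 :: (a -> [b] -> b) -> Tree a -> b
foldTreeVia2 f = foldTree2 f (:) []

-- Hypothetical sample tree.
sample :: Tree Int
sample = Node 1 [Node 2 [Node 4 []], Node 3 []]

main :: IO ()
main = do
    -- Replacing Node, (:) and [] by themselves rebuilds the same tree.
    print (foldTree2 Node (:) [] sample == sample)
    -- Both folds agree when summing all labels.
    let sumAll x xs = x + sum xs
    print (foldTreeVia2 sumAll sample == foldTree sumAll sample)
```

Read this way, the first argument replaces Node, the second replaces (:) in the list of children, and the third replaces [].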

> What benefit does this have over doing the folding by hand?

It lets us avoid writing an explicitly recursive function, and the resulting code often ends up shorter or simpler. Another benefit is that it could participate in fold/build fusion. I have been thinking about this a bit, but perhaps it deserves a separate issue.
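
As a rough illustration of the "shorter or simpler" point, a few common tree functions can be sketched as one-liners with no explicit recursion. The names `sizeT`, `heightT`, `preorder` and `sample` are hypothetical, and nothing here says anything about how well these particular definitions optimize.

```haskell
import Data.Tree (Tree (..), flatten)

-- The proposed fold (not part of Data.Tree).
foldTree2 :: (a -> b -> c) -> (c -> b -> b) -> b -> Tree a -> c
foldTree2 f c z = go where go (Node x ts) = f x (foldr (c . go) z ts)

-- Number of nodes: one for this node plus the total over the children.
sizeT :: Tree a -> Int
sizeT = foldTree2 (\_ n -> 1 + n) (+) 0

-- Number of levels: one more than the tallest child, or 1 for a leaf.
heightT :: Tree a -> Int
heightT = foldTree2 (\_ h -> 1 + h) max 0

-- Labels in pre-order: this label, then the children's labels concatenated.
preorder :: Tree a -> [a]
preorder = foldTree2 (:) (++) []

-- Hypothetical sample tree.
sample :: Tree Int
sample = Node 1 [Node 2 [Node 4 []], Node 3 []]

main :: IO ()
main = do
    print (sizeT sample)                       -- prints 4
    print (heightT sample)                     -- prints 3
    print (preorder sample == flatten sample)  -- prints True
```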
