Skip to content

Commit

Permalink
Speed up equality; remove expensive invariant
Browse files Browse the repository at this point in the history
We used to define `==` to test equality in *priority* order. This was
expensive in and of itself, since it takes logarithmic time to produce
each entry. But the situation was worse than that: since priorities are
not unique, we introduced an extra invariant to make sure that keys with
the same priority are extracted in key order. This requires additional
comparisons in various performance-critical functions.

* Remove the extra invariant.
* Remove the extra code to preserve it.
* Change `(==)` to work in key order instead of priority order.
  Keys are unique, so this is very natural.
* Remove a bunch of now-redundant constraints.
* Document the nondeterminism.
  • Loading branch information
treeowl committed May 16, 2023
1 parent ea54bbc commit 2d971b8
Showing 1 changed file with 42 additions and 56 deletions.
98 changes: 42 additions & 56 deletions src/Data/OrdPSQ/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ module Data.OrdPSQ.Internal

import Control.DeepSeq (NFData (rnf))
import Data.Foldable (Foldable (foldr))
import Data.Function (on)
import qualified Data.List as List
import Data.Maybe (isJust)
import Data.Traversable
Expand Down Expand Up @@ -105,13 +106,8 @@ instance (NFData k, NFData p, NFData v) => NFData (OrdPSQ k p v) where
rnf Void = ()
rnf (Winner e t m) = rnf e `seq` rnf m `seq` rnf t

instance (Ord k, Ord p, Eq v) => Eq (OrdPSQ k p v) where
x == y = case (minView x, minView y) of
(Nothing , Nothing ) -> True
(Just (xk, xp, xv, x'), (Just (yk, yp, yv, y'))) ->
xk == yk && xp == yp && xv == yv && x' == y'
(Just _ , Nothing ) -> False
(Nothing , Just _ ) -> False
instance (Eq k, Eq p, Eq v) => Eq (OrdPSQ k p v) where
(==) = (==) `on` toAscList

type Size = Int

Expand Down Expand Up @@ -155,7 +151,7 @@ member k = isJust . lookup k

-- | /O(log n)/ The priority and value of a given key, or 'Nothing' if the key
-- is not bound.
lookup :: (Ord k) => k -> OrdPSQ k p v -> Maybe (p, v)
lookup :: Ord k => k -> OrdPSQ k p v -> Maybe (p, v)
lookup k = go
where
go t = case tourView t of
Expand Down Expand Up @@ -233,12 +229,14 @@ delete k = go
| k <= m -> go (Winner e' tl m) `play` (Winner e tr m')
| otherwise -> (Winner e' tl m) `play` go (Winner e tr m')

-- | /O(log n)/ Delete the binding with the least priority, and return the
-- rest of the queue stripped of that binding. In case the queue is empty, the
-- empty queue is returned again.
-- | /O(log n)/ Delete one of the bindings with the least priority, and return
-- the rest of the queue stripped of that binding. In case the queue is empty,
-- the empty queue is returned again. If multiple bindings have the least
-- priority, then which one is extracted is nondeterministic, in the sense that
-- different choices may be made for queues that compare `==`.
{-# INLINE deleteMin #-}
deleteMin
:: (Ord k, Ord p) => OrdPSQ k p v -> OrdPSQ k p v
:: Ord p => OrdPSQ k p v -> OrdPSQ k p v
deleteMin t = case minView t of
Nothing -> t
Just (_, _, _, t') -> t'
Expand Down Expand Up @@ -345,22 +343,24 @@ deleteView k psq = case psq of
| k <= m -> fmap (\(p,v,q) -> (p, v, q `play` (Winner e tr m'))) (deleteView k (Winner e' tl m))
| otherwise -> fmap (\(p,v,q) -> (p, v, (Winner e' tl m) `play` q )) (deleteView k (Winner e tr m'))

-- | /O(log n)/ Retrieve the binding with the least priority, and the
-- rest of the queue stripped of that binding.
-- | /O(log n)/ Retrieve one of the bindings with the least priority, and the
-- rest of the queue stripped of that binding. If multiple bindings have the
-- least priority, then which one is extracted is nondeterministic, in the
-- sense that different choices may be made for queues that compare `==`.
{-# INLINABLE minView #-}
minView :: (Ord k, Ord p) => OrdPSQ k p v -> Maybe (k, p, v, OrdPSQ k p v)
minView :: Ord p => OrdPSQ k p v -> Maybe (k, p, v, OrdPSQ k p v)
minView Void = Nothing
minView (Winner (E k p v) t m) = Just (k, p, v, secondBest t m)

secondBest :: (Ord k, Ord p) => LTree k p v -> k -> OrdPSQ k p v
secondBest :: Ord p => LTree k p v -> k -> OrdPSQ k p v
secondBest Start _ = Void
secondBest (LLoser _ e tl m tr) m' = Winner e tl m `play` secondBest tr m'
secondBest (RLoser _ e tl m tr) m' = secondBest tl m `play` Winner e tr m'

-- | Return a list of elements whose priorities are at most @pt@, and the rest
-- of the queue stripped of these elements. The returned list of elements can
-- be in any order: no guarantees there.
atMostView :: (Ord k, Ord p) => p -> OrdPSQ k p v -> ([(k, p, v)], OrdPSQ k p v)
atMostView :: Ord p => p -> OrdPSQ k p v -> ([(k, p, v)], OrdPSQ k p v)
atMostView pt = go []
where
go acc t@(Winner (E _ p _) _ _)
Expand Down Expand Up @@ -462,19 +462,12 @@ tourView (Winner e (LLoser _ e' tl m tr) m') =
-- the two with the precondition that the keys in the first tree are
-- strictly smaller than the keys in the second tree.
{-# INLINABLE play #-}
play :: (Ord p, Ord k) => OrdPSQ k p v -> OrdPSQ k p v -> OrdPSQ k p v
play :: Ord p => OrdPSQ k p v -> OrdPSQ k p v -> OrdPSQ k p v
Void `play` t' = t'
t `play` Void = t
Winner e@(E k p v) t m `play` Winner e'@(E k' p' v') t' m'
| (p, k) `beats` (p', k') = Winner e (rbalance k' p' v' t m t') m'
| otherwise = Winner e' (lbalance k p v t m t') m'

-- | When priorities are equal, the tree with the lowest key wins. This is
-- important to have a deterministic `==`, which relies on `minView` pulling
-- out the elements in the right order.
beats :: (Ord p, Ord k) => (p, k) -> (p, k) -> Bool
beats (p, !k) (p', !k') = p < p' || (p == p' && k < k')
{-# INLINE beats #-}
| p <= p' = Winner e (rbalance k' p' v' t m t') m'
| otherwise = Winner e' (lbalance k p v t m t') m'


--------------------------------------------------------------------------------
Expand Down Expand Up @@ -509,8 +502,7 @@ lloser k p v tl m tr = LLoser (1 + size' tl + size' tr) (E k p v) tl m tr
rloser k p v tl m tr = RLoser (1 + size' tl + size' tr) (E k p v) tl m tr

lbalance, rbalance
:: (Ord k, Ord p)
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
:: Ord p => k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
lbalance k p v l m r
| size' l + size' r < 2 = lloser k p v l m r
| size' r > omega * size' l = lbalanceLeft k p v l m r
Expand All @@ -524,38 +516,33 @@ rbalance k p v l m r
| otherwise = rloser k p v l m r

lbalanceLeft
:: (Ord k, Ord p)
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
:: Ord p => k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
lbalanceLeft k p v l m r
| size' (left r) < size' (right r) = lsingleLeft k p v l m r
| otherwise = ldoubleLeft k p v l m r

lbalanceRight
:: (Ord k, Ord p)
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
:: Ord p => k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
lbalanceRight k p v l m r
| size' (left l) > size' (right l) = lsingleRight k p v l m r
| otherwise = ldoubleRight k p v l m r

rbalanceLeft
:: (Ord k, Ord p)
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
:: Ord p => k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
rbalanceLeft k p v l m r
| size' (left r) < size' (right r) = rsingleLeft k p v l m r
| otherwise = rdoubleLeft k p v l m r

rbalanceRight
:: (Ord k, Ord p)
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
:: Ord p => k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
rbalanceRight k p v l m r
| size' (left l) > size' (right l) = rsingleRight k p v l m r
| otherwise = rdoubleRight k p v l m r

lsingleLeft
:: (Ord k, Ord p)
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
:: Ord p => k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
lsingleLeft k1 p1 v1 t1 m1 (LLoser _ (E k2 p2 v2) t2 m2 t3)
| (p1, k1) `beats` (p2, k2) =
| p1 <= p2 =
lloser k1 p1 v1 (rloser k2 p2 v2 t1 m1 t2) m2 t3
| otherwise =
lloser k2 p2 v2 (lloser k1 p1 v1 t1 m1 t2) m2 t3
Expand All @@ -578,19 +565,18 @@ lsingleRight k1 p1 v1 (RLoser _ (E k2 p2 v2) t1 m1 t2) m2 t3 =
lsingleRight _ _ _ _ _ _ = moduleError "lsingleRight" "malformed tree"

rsingleRight
:: (Ord k, Ord p)
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
:: Ord p => k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
rsingleRight k1 p1 v1 (LLoser _ (E k2 p2 v2) t1 m1 t2) m2 t3 =
lloser k2 p2 v2 t1 m1 (rloser k1 p1 v1 t2 m2 t3)
rsingleRight k1 p1 v1 (RLoser _ (E k2 p2 v2) t1 m1 t2) m2 t3
| (p1, k1) `beats` (p2, k2) =
| p1 <= p2 =
rloser k1 p1 v1 t1 m1 (lloser k2 p2 v2 t2 m2 t3)
| otherwise =
rloser k2 p2 v2 t1 m1 (rloser k1 p1 v1 t2 m2 t3)
rsingleRight _ _ _ _ _ _ = moduleError "rsingleRight" "malformed tree"

ldoubleLeft
:: (Ord k, Ord p)
:: Ord p
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
ldoubleLeft k1 p1 v1 t1 m1 (LLoser _ (E k2 p2 v2) t2 m2 t3) =
lsingleLeft k1 p1 v1 t1 m1 (lsingleRight k2 p2 v2 t2 m2 t3)
Expand All @@ -599,7 +585,7 @@ ldoubleLeft k1 p1 v1 t1 m1 (RLoser _ (E k2 p2 v2) t2 m2 t3) =
ldoubleLeft _ _ _ _ _ _ = moduleError "ldoubleLeft" "malformed tree"

ldoubleRight
:: (Ord k, Ord p)
:: Ord p
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
ldoubleRight k1 p1 v1 (LLoser _ (E k2 p2 v2) t1 m1 t2) m2 t3 =
lsingleRight k1 p1 v1 (lsingleLeft k2 p2 v2 t1 m1 t2) m2 t3
Expand All @@ -608,7 +594,7 @@ ldoubleRight k1 p1 v1 (RLoser _ (E k2 p2 v2) t1 m1 t2) m2 t3 =
ldoubleRight _ _ _ _ _ _ = moduleError "ldoubleRight" "malformed tree"

rdoubleLeft
:: (Ord k, Ord p)
:: Ord p
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
rdoubleLeft k1 p1 v1 t1 m1 (LLoser _ (E k2 p2 v2) t2 m2 t3) =
rsingleLeft k1 p1 v1 t1 m1 (lsingleRight k2 p2 v2 t2 m2 t3)
Expand All @@ -617,7 +603,7 @@ rdoubleLeft k1 p1 v1 t1 m1 (RLoser _ (E k2 p2 v2) t2 m2 t3) =
rdoubleLeft _ _ _ _ _ _ = moduleError "rdoubleLeft" "malformed tree"

rdoubleRight
:: (Ord k, Ord p)
:: Ord p
=> k -> p -> v -> LTree k p v -> k -> LTree k p v -> LTree k p v
rdoubleRight k1 p1 v1 (LLoser _ (E k2 p2 v2) t1 m1 t2) m2 t3 =
rsingleRight k1 p1 v1 (lsingleLeft k2 p2 v2 t1 m1 t2) m2 t3
Expand All @@ -642,19 +628,19 @@ valid t =
hasDuplicateKeys :: Ord k => OrdPSQ k p v -> Bool
hasDuplicateKeys = any (> 1) . List.map length . List.group . List.sort . keys

hasMinHeapProperty :: forall k p v. (Ord k, Ord p) => OrdPSQ k p v -> Bool
hasMinHeapProperty :: forall k p v. Ord p => OrdPSQ k p v -> Bool
hasMinHeapProperty Void = True
hasMinHeapProperty (Winner (E k0 p0 _) t0 _) = go k0 p0 t0
hasMinHeapProperty (Winner (E _k0 p0 _) t0 _) = go p0 t0
where
go :: k -> p -> LTree k p v -> Bool
go _ _ Start = True
go k p (LLoser _ (E k' p' _) l _ r) =
(p, k) < (p', k') && go k' p' l && go k p r
go k p (RLoser _ (E k' p' _) l _ r) =
(p, k) < (p', k') && go k p l && go k' p' r
go :: p -> LTree k p v -> Bool
go _ Start = True
go p (LLoser _ (E _k' p' _) l _ r) =
p <= p' && go p' l && go p r
go p (RLoser _ (E _k' p' _) l _ r) =
p <= p' && go p l && go p' r

hasBinarySearchTreeProperty
:: forall k p v. (Ord k, Ord p) => OrdPSQ k p v -> Bool
:: forall k p v. Ord k => OrdPSQ k p v -> Bool
hasBinarySearchTreeProperty t = case tourView t of
Null -> True
Single _ -> True
Expand Down

0 comments on commit 2d971b8

Please sign in to comment.