diff --git a/clash-prelude/src/Clash/Sized/RTree.hs b/clash-prelude/src/Clash/Sized/RTree.hs index 3584ebbf1c..b2472d2d3d 100644 --- a/clash-prelude/src/Clash/Sized/RTree.hs +++ b/clash-prelude/src/Clash/Sized/RTree.hs @@ -17,15 +17,15 @@ Maintainer : QBayLogic B.V. {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-} -{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} - module Clash.Sized.RTree ( -- * 'RTree' data type - RTree (LR, BR) + RTree (LR, BR, LR_, BR_) -- * Construction , treplicate , trepeat -- * Accessors + , thead + , tlast -- ** Indexing , indexTree , tindices @@ -44,6 +44,8 @@ module Clash.Sized.RTree -- ** Specialised folds , tdfold -- ** Prefix sums (scans) + , tscanl + , tscanr , scanlPar , scanrPar -- * Conversions @@ -63,9 +65,7 @@ import Data.Foldable (toList) import Data.Kind (Type) import Data.Singletons (Apply, TyFun, type (@@)) import Data.Proxy (Proxy (..)) -import Data.Type.Equality ((:~:) (..)) -import GHC.TypeLits (KnownNat, Nat, type (+), type (^), type (*), - sameNat) +import GHC.TypeLits (KnownNat, Nat, type (+), type (^), type (*)) import Language.Haskell.TH.Syntax (Lift(..)) #if MIN_VERSION_template_haskell(2,16,0) import Language.Haskell.TH.Compat @@ -75,7 +75,7 @@ import Test.QuickCheck (Arbitrary (..), CoArbitrary (..)) import Clash.Annotations.Primitive (hasBlackBox) import Clash.Class.BitPack (BitPack (..), packXWith) -import Clash.Promoted.Nat (SNat (..), SNatLE (..), UNat (..), compareSNat, +import Clash.Promoted.Nat (SNat (..), UNat (..), pow2SNat, snatToNum, subSNat, toUNat) import Clash.Promoted.Nat.Literals (d1) import Clash.Sized.Index (Index) @@ -553,13 +553,13 @@ lazyT = tzipWith (flip const) (trepeat ()) -- The operator must be associative. -- -- <> -scanlPar :: +scanlPar :: KnownNat n => -- | Must be associative (a -> a -> a) -> Vec (2^n) a -> Vec (2^n) a -scanlPar op v = scanlInductiveRTree op (v2t v) +scanlPar op = t2v . tscanl op . v2t {-# INLINE scanlPar #-} -- | Low-depth (right) scan @@ -568,47 +568,59 @@ scanlPar op v = scanlInductiveRTree op (v2t v) -- 10 :> 9 :> 7 :> 4 :> Nil -- -- The operator must be associative. -scanrPar :: +scanrPar :: KnownNat n => -- | Must be associative (a -> a -> a) -> Vec (2^n) a -> Vec (2^n) a -scanrPar op v = scanrInductiveRTree op (v2t v) +scanrPar op = t2v . tscanr op . v2t {-# INLINE scanrPar #-} -scanlInductiveRTree :: +-- | +-- +-- >>> thead $ BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4)) +-- 1 +thead :: RTree n a -> a +thead (LR_ x) = x +thead (BR_ x _) = thead x + +-- | +-- +-- >>> tlast $ BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4)) +-- 4 +tlast :: RTree n a -> a +tlast (LR_ x) = x +tlast (BR_ _ y) = tlast y + +tscanl :: forall a n. KnownNat n => (a -> a -> a) -> RTree n a -> - Vec (2^n) a -scanlInductiveRTree op tr = - -- I have to use sameNat and compareSNat both; the type checker cannot infer - -- that n <= 0 means n ~ 0. - case (sameNat (Proxy @n) (Proxy @0), compareSNat (SNat @n) (SNat @0), tr) of - (Just Refl, _, LR x) -> x :> Nil - (_, SNatGT, BR x y) -> - let - x' = scanlInductiveRTree op x - y' = scanlInductiveRTree op y - l = x' !! (length x'-1) -- 'last' doesn't work here - in x' ++ map (l `op`) y' - -scanrInductiveRTree :: + RTree n a +tscanl op tr = + case tr of + (LR_ x) -> LR x -- :> Nil + (BR_ x y) -> + let + x' = tscanl op x + y' = tscanl op y + l = tlast x' + in BR x' (fmap (l `op`) y') + +tscanr :: forall a n. KnownNat n => (a -> a -> a) -> RTree n a -> - Vec (2^n) a -scanrInductiveRTree op tr = - -- I have to use sameNat and compareSNat both; the type checker cannot infer - -- that n <= 0 means n ~ 0. - case (sameNat (Proxy @n) (Proxy @0), compareSNat (SNat @n) (SNat @0), tr) of - (Just Refl, _, LR x) -> x :> Nil - (_, SNatGT, BR x y) -> + RTree n a +tscanr op tr = + case tr of + (LR_ x) -> LR x + (BR_ x y) -> let - x' = scanrInductiveRTree op x - y' = scanrInductiveRTree op y - l = y' !! (0::Int) -- `head` doesn't work here - in map (l `op`) x' ++ y' + x' = tscanr op x + y' = tscanr op y + l = thead y' + in BR (fmap (l `op`) x') y'