Skip to content

Commit

Permalink
Use non-GADT patterns
Browse files Browse the repository at this point in the history
Export tscanl, tscanr
  • Loading branch information
vmchale committed Apr 25, 2022
1 parent 2ab05e9 commit b9c181f
Showing 1 changed file with 49 additions and 37 deletions.
86 changes: 49 additions & 37 deletions clash-prelude/src/Clash/Sized/RTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ Maintainer : QBayLogic B.V. <devops@qbaylogic.com>

{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-}

{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}

module Clash.Sized.RTree
( -- * 'RTree' data type
RTree (LR, BR)
RTree (LR, BR, LR_, BR_)
-- * Construction
, treplicate
, trepeat
-- * Accessors
, thead
, tlast
-- ** Indexing
, indexTree
, tindices
Expand All @@ -44,6 +44,8 @@ module Clash.Sized.RTree
-- ** Specialised folds
, tdfold
-- ** Prefix sums (scans)
, tscanl
, tscanr
, scanlPar
, scanrPar
-- * Conversions
Expand All @@ -63,9 +65,7 @@ import Data.Foldable (toList)
import Data.Kind (Type)
import Data.Singletons (Apply, TyFun, type (@@))
import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))
import GHC.TypeLits (KnownNat, Nat, type (+), type (^), type (*),
sameNat)
import GHC.TypeLits (KnownNat, Nat, type (+), type (^), type (*))
import Language.Haskell.TH.Syntax (Lift(..))
#if MIN_VERSION_template_haskell(2,16,0)
import Language.Haskell.TH.Compat
Expand All @@ -75,7 +75,7 @@ import Test.QuickCheck (Arbitrary (..), CoArbitrary (..))

import Clash.Annotations.Primitive (hasBlackBox)
import Clash.Class.BitPack (BitPack (..), packXWith)
import Clash.Promoted.Nat (SNat (..), SNatLE (..), UNat (..), compareSNat,
import Clash.Promoted.Nat (SNat (..), UNat (..),
pow2SNat, snatToNum, subSNat, toUNat)
import Clash.Promoted.Nat.Literals (d1)
import Clash.Sized.Index (Index)
Expand Down Expand Up @@ -553,13 +553,13 @@ lazyT = tzipWith (flip const) (trepeat ())
-- The operator must be associative.
--
-- <<doc/scanlPar.svg>>
scanlPar ::
scanlPar ::
KnownNat n =>
-- | Must be associative
(a -> a -> a) ->
Vec (2^n) a ->
Vec (2^n) a
scanlPar op v = scanlInductiveRTree op (v2t v)
scanlPar op = t2v . tscanl op . v2t
{-# INLINE scanlPar #-}

-- | Low-depth (right) scan
Expand All @@ -568,47 +568,59 @@ scanlPar op v = scanlInductiveRTree op (v2t v)
-- 10 :> 9 :> 7 :> 4 :> Nil
--
-- The operator must be associative.
scanrPar ::
scanrPar ::
KnownNat n =>
-- | Must be associative
(a -> a -> a) ->
Vec (2^n) a ->
Vec (2^n) a
scanrPar op v = scanrInductiveRTree op (v2t v)
scanrPar op = t2v . tscanr op . v2t
{-# INLINE scanrPar #-}

scanlInductiveRTree ::
-- |
--
-- >>> thead $ BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4))
-- 1
thead :: RTree n a -> a
thead (LR_ x) = x
thead (BR_ x _) = thead x

-- |
--
-- >>> tlast $ BR (BR (LR 1) (LR 2)) (BR (LR 3) (LR 4))
-- 4
tlast :: RTree n a -> a
tlast (LR_ x) = x
tlast (BR_ _ y) = tlast y

tscanl ::
forall a n.
KnownNat n =>
(a -> a -> a) ->
RTree n a ->
Vec (2^n) a
scanlInductiveRTree op tr =
-- I have to use sameNat and compareSNat both; the type checker cannot infer
-- that n <= 0 means n ~ 0.
case (sameNat (Proxy @n) (Proxy @0), compareSNat (SNat @n) (SNat @0), tr) of
(Just Refl, _, LR x) -> x :> Nil
(_, SNatGT, BR x y) ->
let
x' = scanlInductiveRTree op x
y' = scanlInductiveRTree op y
l = x' !! (length x'-1) -- 'last' doesn't work here
in x' ++ map (l `op`) y'

scanrInductiveRTree ::
RTree n a
tscanl op tr =
case tr of
(LR_ x) -> LR x -- :> Nil
(BR_ x y) ->
let
x' = tscanl op x
y' = tscanl op y
l = tlast x'
in BR x' (fmap (l `op`) y')

tscanr ::
forall a n.
KnownNat n =>
(a -> a -> a) ->
RTree n a ->
Vec (2^n) a
scanrInductiveRTree op tr =
-- I have to use sameNat and compareSNat both; the type checker cannot infer
-- that n <= 0 means n ~ 0.
case (sameNat (Proxy @n) (Proxy @0), compareSNat (SNat @n) (SNat @0), tr) of
(Just Refl, _, LR x) -> x :> Nil
(_, SNatGT, BR x y) ->
RTree n a
tscanr op tr =
case tr of
(LR_ x) -> LR x
(BR_ x y) ->
let
x' = scanrInductiveRTree op x
y' = scanrInductiveRTree op y
l = y' !! (0::Int) -- `head` doesn't work here
in map (l `op`) x' ++ y'
x' = tscanr op x
y' = tscanr op y
l = thead y'
in BR (fmap (l `op`) x') y'

0 comments on commit b9c181f

Please sign in to comment.