Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ST instead of polymorphic m in Vector/MVector type classes #335

Merged
merged 4 commits into from
Jan 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Data/Vector/Fusion/Bundle/Monadic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fromStream :: Monad m => Stream m a -> Size -> Bundle m v a
fromStream (Stream step t) sz = Bundle (Stream step t) (Stream step' t) Nothing sz
where
step' s = do r <- step s
return $ fmap (\x -> Chunk 1 (\v -> M.basicUnsafeWrite v 0 x)) r
return $ fmap (\x -> Chunk 1 (\v -> stToPrim $ M.basicUnsafeWrite v 0 x)) r

chunks :: Bundle m v a -> Stream m (Chunk v a)
{-# INLINE chunks #-}
Expand Down Expand Up @@ -185,7 +185,7 @@ singleton x = fromStream (S.singleton x) (Exact 1)
replicate :: Monad m => Int -> a -> Bundle m v a
{-# INLINE_FUSED replicate #-}
replicate n x = Bundle (S.replicate n x)
(S.singleton $ Chunk len (\v -> M.basicSet v x))
(S.singleton $ Chunk len (\v -> stToPrim $ M.basicSet v x))
Nothing
(Exact len)
where
Expand Down Expand Up @@ -1086,7 +1086,7 @@ fromVector v = v `seq` n `seq` Bundle (Stream step 0)


{-# INLINE vstep #-}
vstep True = return (Yield (Chunk (basicLength v) (\mv -> basicUnsafeCopy mv v)) False)
vstep True = return (Yield (Chunk (basicLength v) (\mv -> stToPrim $ basicUnsafeCopy mv v)) False)
vstep False = return Done

fromVectors :: forall m v a. (Monad m, Vector v a) => [v a] -> Bundle m v a
Expand All @@ -1112,7 +1112,7 @@ fromVectors us = Bundle (Stream pstep (Left us))
vstep (v:vs) = return $ Yield (Chunk (basicLength v)
(\mv -> INTERNAL_CHECK(check) "concatVectors" "length mismatch"
(M.basicLength mv == basicLength v)
$ basicUnsafeCopy mv v)) vs
$ stToPrim $ basicUnsafeCopy mv v)) vs


concatVectors :: (Monad m, Vector v a) => Bundle m u (v a) -> Bundle m v a
Expand Down Expand Up @@ -1142,7 +1142,7 @@ concatVectors Bundle{sElems = Stream step t}
Yield v s' -> return (Yield (Chunk (basicLength v)
(\mv -> INTERNAL_CHECK(check) "concatVectors" "length mismatch"
(M.basicLength mv == basicLength v)
$ basicUnsafeCopy mv v)) s')
$ stToPrim $ basicUnsafeCopy mv v)) s')
Skip s' -> return (Skip s')
Done -> return Done

Expand Down
6 changes: 5 additions & 1 deletion Data/Vector/Fusion/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
--

module Data.Vector.Fusion.Util (
Id(..), Box(..),
Id(..), Box(..), liftBox,

delay_inline, delayed_min
) where
Expand Down Expand Up @@ -45,6 +45,10 @@ instance Monad Box where
return = pure
Box x >>= f = f x

liftBox :: Monad m => Box a -> m a
liftBox (Box a) = return a
{-# INLINE liftBox #-}

-- | Delay inlining a function until late in the game (simplifier phase 0).
delay_inline :: (a -> b) -> a -> b
{-# INLINE [0] delay_inline #-}
Expand Down
12 changes: 7 additions & 5 deletions Data/Vector/Generic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ infixl 9 !
(!) :: Vector v a => v a -> Int -> a
{-# INLINE_FUSED (!) #-}
(!) v i = BOUNDS_CHECK(checkIndex) "(!)" i (length v)
$ unId (basicUnsafeIndexM v i)
$ unBox (basicUnsafeIndexM v i)

infixl 9 !?
-- | O(1) Safe indexing
Expand All @@ -258,7 +258,7 @@ last v = v ! (length v - 1)
unsafeIndex :: Vector v a => v a -> Int -> a
{-# INLINE_FUSED unsafeIndex #-}
unsafeIndex v i = UNSAFE_CHECK(checkIndex) "unsafeIndex" i (length v)
$ unId (basicUnsafeIndexM v i)
$ unBox (basicUnsafeIndexM v i)

-- | /O(1)/ First element without checking if the vector is empty
unsafeHead :: Vector v a => v a -> a
Expand Down Expand Up @@ -320,6 +320,7 @@ unsafeLast v = unsafeIndex v (length v - 1)
indexM :: (Vector v a, Monad m) => v a -> Int -> m a
{-# INLINE_FUSED indexM #-}
indexM v i = BOUNDS_CHECK(checkIndex) "indexM" i (length v)
$ liftBox
$ basicUnsafeIndexM v i

-- | /O(1)/ First element of a vector in a monad. See 'indexM' for an
Expand All @@ -339,6 +340,7 @@ lastM v = indexM v (length v - 1)
unsafeIndexM :: (Vector v a, Monad m) => v a -> Int -> m a
{-# INLINE_FUSED unsafeIndexM #-}
unsafeIndexM v i = UNSAFE_CHECK(checkIndex) "unsafeIndexM" i (length v)
$ liftBox
$ basicUnsafeIndexM v i

-- | /O(1)/ First element in a monad without checking for empty vectors.
Expand Down Expand Up @@ -2002,7 +2004,7 @@ convert = unstream . Bundle.reVector . stream
unsafeFreeze
:: (PrimMonad m, Vector v a) => Mutable v (PrimState m) a -> m (v a)
{-# INLINE unsafeFreeze #-}
unsafeFreeze = basicUnsafeFreeze
unsafeFreeze = stToPrim . basicUnsafeFreeze

-- | /O(n)/ Yield an immutable copy of the mutable vector.
freeze :: (PrimMonad m, Vector v a) => Mutable v (PrimState m) a -> m (v a)
Expand All @@ -2013,7 +2015,7 @@ freeze mv = unsafeFreeze =<< M.clone mv
-- copying. The immutable vector may not be used after this operation.
unsafeThaw :: (PrimMonad m, Vector v a) => v a -> m (Mutable v (PrimState m) a)
{-# INLINE_FUSED unsafeThaw #-}
unsafeThaw = basicUnsafeThaw
unsafeThaw = stToPrim . basicUnsafeThaw

-- | /O(n)/ Yield a mutable copy of the immutable vector.
thaw :: (PrimMonad m, Vector v a) => v a -> m (Mutable v (PrimState m) a)
Expand Down Expand Up @@ -2072,7 +2074,7 @@ unsafeCopy
{-# INLINE unsafeCopy #-}
unsafeCopy dst src = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch"
(M.length dst == basicLength src)
$ (dst `seq` src `seq` basicUnsafeCopy dst src)
$ (dst `seq` src `seq` stToPrim (basicUnsafeCopy dst src))

-- Conversions to/from Bundles
-- ---------------------------
Expand Down
12 changes: 7 additions & 5 deletions Data/Vector/Generic/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ module Data.Vector.Generic.Base (

import Data.Vector.Generic.Mutable.Base ( MVector )
import qualified Data.Vector.Generic.Mutable.Base as M
import Data.Vector.Fusion.Util (Box(..), liftBox)

import Control.Monad.ST
import Control.Monad.Primitive

-- | @Mutable v s a@ is the mutable version of the pure vector type @v a@ with
Expand Down Expand Up @@ -59,13 +61,13 @@ class MVector (Mutable v) a => Vector v a where
-- Unsafely convert a mutable vector to its immutable version
-- without copying. The mutable vector may not be used after
-- this operation.
basicUnsafeFreeze :: PrimMonad m => Mutable v (PrimState m) a -> m (v a)
basicUnsafeFreeze :: Mutable v s a -> ST s (v a)

-- | /Assumed complexity: O(1)/
--
-- Unsafely convert an immutable vector to its mutable version without
-- copying. The immutable vector may not be used after this operation.
basicUnsafeThaw :: PrimMonad m => v a -> m (Mutable v (PrimState m) a)
basicUnsafeThaw :: v a -> ST s (Mutable v s a)

-- | /Assumed complexity: O(1)/
--
Expand Down Expand Up @@ -105,7 +107,7 @@ class MVector (Mutable v) a => Vector v a where
-- which does not have this problem because indexing (but not the returned
-- element!) is evaluated immediately.
--
basicUnsafeIndexM :: Monad m => v a -> Int -> m a
basicUnsafeIndexM :: v a -> Int -> Box a

-- | /Assumed complexity: O(n)/
--
Expand All @@ -117,15 +119,15 @@ class MVector (Mutable v) a => Vector v a where
--
-- Default definition: copying basic on 'basicUnsafeIndexM' and
-- 'basicUnsafeWrite'.
basicUnsafeCopy :: PrimMonad m => Mutable v (PrimState m) a -> v a -> m ()
basicUnsafeCopy :: Mutable v s a -> v a -> ST s ()

{-# INLINE basicUnsafeCopy #-}
basicUnsafeCopy !dst !src = do_copy 0
where
!n = basicLength src

do_copy i | i < n = do
x <- basicUnsafeIndexM src i
x <- liftBox $ basicUnsafeIndexM src i
M.basicUnsafeWrite dst i x
do_copy (i+1)
| otherwise = return ()
Expand Down
35 changes: 22 additions & 13 deletions Data/Vector/Generic/Mutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ import qualified Data.Vector.Fusion.Stream.Monadic as Stream
import Data.Vector.Fusion.Bundle.Size
import Data.Vector.Fusion.Util ( delay_inline )

import Control.Monad.Primitive ( PrimMonad(..), RealWorld )
import Control.Monad.Primitive ( PrimMonad(..), RealWorld, stToPrim )

import Prelude hiding ( length, null, replicate, reverse, map, read,
take, drop, splitAt, init, tail )
Expand Down Expand Up @@ -591,6 +591,7 @@ overlaps = basicOverlaps
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new n = BOUNDS_CHECK(checkLength) "new" n
$ stToPrim
$ unsafeNew n >>= \v -> basicInitialize v >> return v

-- | Create a mutable vector of the given length. The vector content
Expand All @@ -604,13 +605,14 @@ new n = BOUNDS_CHECK(checkLength) "new" n
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n
$ stToPrim
$ basicUnsafeNew n

-- | Create a mutable vector of the given length (0 if the length is negative)
-- and fill it with an initial value.
replicate :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE replicate #-}
replicate n x = basicUnsafeReplicate (delay_inline max 0 n) x
replicate n x = stToPrim $ basicUnsafeReplicate (delay_inline max 0 n) x

-- | Create a mutable vector of the given length (0 if the length is negative)
-- and fill it with values produced by repeatedly executing the monadic action.
Expand All @@ -635,6 +637,7 @@ grow :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v by = BOUNDS_CHECK(checkLength) "grow" by
$ stToPrim
$ do vnew <- unsafeGrow v by
basicInitialize $ basicUnsafeSlice (length v) by vnew
return vnew
Expand All @@ -643,6 +646,7 @@ growFront :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE growFront #-}
growFront v by = BOUNDS_CHECK(checkLength) "growFront" by
$ stToPrim
$ do vnew <- unsafeGrowFront v by
basicInitialize $ basicUnsafeSlice 0 by vnew
return vnew
Expand All @@ -654,16 +658,17 @@ enlarge_delta v = max (length v) 1
enlarge :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = do vnew <- unsafeGrow v by
basicInitialize $ basicUnsafeSlice (length v) by vnew
return vnew
enlarge v = stToPrim $ do
vnew <- unsafeGrow v by
basicInitialize $ basicUnsafeSlice (length v) by vnew
return vnew
where
by = enlarge_delta v

enlargeFront :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = do
enlargeFront v = stToPrim $ do
v' <- unsafeGrowFront v by
basicInitialize $ basicUnsafeSlice 0 by v'
return (v', by)
Expand All @@ -676,13 +681,14 @@ unsafeGrow :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v n = UNSAFE_CHECK(checkLength) "unsafeGrow" n
$ stToPrim
$ basicUnsafeGrow v n

unsafeGrowFront :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
$ do
$ stToPrim $ do
let n = length v
v' <- basicUnsafeNew (by+n)
basicUnsafeCopy (basicUnsafeSlice by n v') v
Expand All @@ -695,7 +701,7 @@ unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
-- references to external objects. This is usually a noop for unboxed vectors.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear = basicClear
clear = stToPrim . basicClear

-- Accessing individual elements
-- -----------------------------
Expand Down Expand Up @@ -735,19 +741,22 @@ exchange v i x = BOUNDS_CHECK(checkIndex) "exchange" i (length v)
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v i = UNSAFE_CHECK(checkIndex) "unsafeRead" i (length v)
$ stToPrim
$ basicUnsafeRead v i

-- | Replace the element at the given position. No bounds checks are performed.
unsafeWrite :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v i x = UNSAFE_CHECK(checkIndex) "unsafeWrite" i (length v)
$ stToPrim
$ basicUnsafeWrite v i x

-- | Modify the element at the given position. No bounds checks are performed.
unsafeModify :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> a) -> Int -> m ()
{-# INLINE unsafeModify #-}
unsafeModify v f i = UNSAFE_CHECK(checkIndex) "unsafeModify" i (length v)
$ stToPrim
$ basicUnsafeRead v i >>= \x ->
basicUnsafeWrite v i (f x)

Expand All @@ -757,7 +766,7 @@ unsafeSwap :: (PrimMonad m, MVector v a)
{-# INLINE unsafeSwap #-}
unsafeSwap v i j = UNSAFE_CHECK(checkIndex) "unsafeSwap" i (length v)
$ UNSAFE_CHECK(checkIndex) "unsafeSwap" j (length v)
$ do
$ stToPrim $ do
x <- unsafeRead v i
y <- unsafeRead v j
unsafeWrite v i y
Expand All @@ -769,7 +778,7 @@ unsafeExchange :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> a -> m a
{-# INLINE unsafeExchange #-}
unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
$ do
$ stToPrim $ do
y <- unsafeRead v i
unsafeWrite v i x
return y
Expand All @@ -780,7 +789,7 @@ unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
-- | Set all elements of the vector to the given value.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set = basicSet
set v = stToPrim . basicSet v

-- | Copy a vector. The two vectors must have the same length and may not
-- overlap.
Expand Down Expand Up @@ -820,7 +829,7 @@ unsafeCopy dst src = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch"
(length dst == length src)
$ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors"
(not (dst `overlaps` src))
$ (dst `seq` src `seq` basicUnsafeCopy dst src)
$ (dst `seq` src `seq` stToPrim (basicUnsafeCopy dst src))

-- | Move the contents of a vector. The two vectors must have the same
-- length, but this is not checked.
Expand All @@ -835,7 +844,7 @@ unsafeMove :: (PrimMonad m, MVector v a) => v (PrimState m) a -- ^ target
{-# INLINE unsafeMove #-}
unsafeMove dst src = UNSAFE_CHECK(check) "unsafeMove" "length mismatch"
(length dst == length src)
$ (dst `seq` src `seq` basicUnsafeMove dst src)
$ (dst `seq` src `seq` stToPrim (basicUnsafeMove dst src))

-- Permutations
-- ------------
Expand Down
Loading