Skip to content

Commit

Permalink
Remove unnecessary Int-Word conversions (#1058)
Browse files Browse the repository at this point in the history
This simplifies the code and the GHC Core. It is not expected to affect
performance since Int-Word conversions are free at runtime.

Additionally,
* Remove the Nat synonym
* Document some preconditions
  • Loading branch information
meooow25 authored Nov 16, 2024
1 parent 5b3da8f commit 171b2e6
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 76 deletions.
25 changes: 5 additions & 20 deletions containers/src/Data/IntMap/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,7 @@ module Data.IntMap.Internal (
, showTree
, showTreeWith

-- * Internal types
, Nat

-- * Utility
, natFromInt
, intFromNat
, link
, linkKey
, linkWithMask
Expand Down Expand Up @@ -313,8 +308,9 @@ import Data.IntSet.Internal.IntTreeCommons
, branchMask
, TreeTreeBranch(..)
, treeTreeBranch
, i2w
)
import Utils.Containers.Internal.BitUtil
import Utils.Containers.Internal.BitUtil (shiftLL, shiftRL, iShiftRL)
import Utils.Containers.Internal.StrictPair

#ifdef __GLASGOW_HASKELL__
Expand All @@ -334,17 +330,6 @@ import Text.Read
import qualified Control.Category as Category


-- A "Nat" is a natural machine word (an unsigned Int)
type Nat = Word

natFromInt :: Key -> Nat
natFromInt = fromIntegral
{-# INLINE natFromInt #-}

intFromNat :: Nat -> Key
intFromNat = fromIntegral
{-# INLINE intFromNat #-}

{--------------------------------------------------------------------
Types
--------------------------------------------------------------------}
Expand Down Expand Up @@ -2146,7 +2131,7 @@ mergeA
-> Int -> f (IntMap a)
-> f (IntMap a)
linkA k1 t1 k2 t2
| natFromInt k1 < natFromInt k2 = binA p t1 t2
| i2w k1 < i2w k2 = binA p t1 t2
| otherwise = binA p t2 t1
where
m = branchMask k1 k2
Expand Down Expand Up @@ -3178,7 +3163,7 @@ fromSet f (IntSet.Tip kx bm) = buildTree f kx bm (IntSet.suffixBitMask + 1)
-- and we construct the IntMap from that half.
buildTree g !prefix !bmask bits = case bits of
0 -> Tip prefix (g prefix)
_ -> case intFromNat ((natFromInt bits) `shiftRL` 1) of
_ -> case bits `iShiftRL` 1 of
bits2
| bmask .&. ((1 `shiftLL` bits2) - 1) == 0 ->
buildTree g (prefix + bits2) (bmask `shiftRL` bits2) bits2
Expand Down Expand Up @@ -3552,7 +3537,7 @@ link k1 t1 k2 t2 = linkWithMask (branchMask k1 k2) k1 t1 k2 t2
-- `linkWithMask` is useful when the `branchMask` has already been computed
linkWithMask :: Int -> Key -> IntMap a -> Key -> IntMap a -> IntMap a
linkWithMask m k1 t1 k2 t2
| natFromInt k1 < natFromInt k2 = Bin p t1 t2
| i2w k1 < i2w k2 = Bin p t1 t2
| otherwise = Bin p t2 t1
where
p = Prefix (mask k1 m .|. m)
Expand Down
6 changes: 2 additions & 4 deletions containers/src/Data/IntMap/Strict/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,6 @@ import Data.IntSet.Internal.IntTreeCommons
(Key, Prefix(..), nomatch, left, signBranch, mask, branchMask)
import Data.IntMap.Internal
( IntMap (..)
, natFromInt
, intFromNat
, bin
, binCheckLeft
, binCheckRight
Expand Down Expand Up @@ -346,7 +344,7 @@ import Data.IntMap.Internal
, withoutKeys
)
import qualified Data.IntSet.Internal as IntSet
import Utils.Containers.Internal.BitUtil
import Utils.Containers.Internal.BitUtil (iShiftRL, shiftLL, shiftRL)
import Utils.Containers.Internal.StrictPair
import qualified Data.Foldable as Foldable

Expand Down Expand Up @@ -1056,7 +1054,7 @@ fromSet f (IntSet.Tip kx bm) = buildTree f kx bm (IntSet.suffixBitMask + 1)
-- one of them is nonempty and we construct the IntMap from that half.
buildTree g !prefix !bmask bits = case bits of
0 -> Tip prefix $! g prefix
_ -> case intFromNat ((natFromInt bits) `shiftRL` 1) of
_ -> case bits `iShiftRL` 1 of
bits2 | bmask .&. ((1 `shiftLL` bits2) - 1) == 0 ->
buildTree g (prefix + bits2) (bmask `shiftRL` bits2) bits2
| (bmask `shiftRL` bits2) .&. ((1 `shiftLL` bits2) - 1) == 0 ->
Expand Down
56 changes: 23 additions & 33 deletions containers/src/Data/IntSet/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ import Utils.Containers.Internal.Prelude hiding
(filter, foldr, foldl, foldl', foldMap, null, map)
import Prelude ()

import Utils.Containers.Internal.BitUtil
import Utils.Containers.Internal.BitUtil (iShiftRL, shiftLL, shiftRL)
import Utils.Containers.Internal.StrictPair
import Data.IntSet.Internal.IntTreeCommons
( Key
Expand All @@ -217,6 +217,7 @@ import Data.IntSet.Internal.IntTreeCommons
, branchMask
, TreeTreeBranch(..)
, treeTreeBranch
, i2w
)

#if __GLASGOW_HASKELL__
Expand All @@ -240,17 +241,6 @@ import Data.Functor.Identity (Identity(..))

infixl 9 \\{-This comment teaches CPP correct behaviour -}

-- A "Nat" is a natural machine word (an unsigned Int)
type Nat = Word

natFromInt :: Int -> Nat
natFromInt i = fromIntegral i
{-# INLINE natFromInt #-}

intFromNat :: Nat -> Int
intFromNat w = fromIntegral w
{-# INLINE intFromNat #-}

{--------------------------------------------------------------------
Operators
--------------------------------------------------------------------}
Expand Down Expand Up @@ -1388,10 +1378,10 @@ fromRange (lx,rx)
| m < suffixBitMask = Tip p (complement 0)
| otherwise = Bin (Prefix (p .|. m)) (goFull p (shr1 m)) (goFull (p .|. m) (shr1 m))
lbm :: Int -> Int
lbm p = intFromNat (lowestBitMask (natFromInt p))
lbm p = p .&. negate p -- lowest bit mask
{-# INLINE lbm #-}
shr1 :: Int -> Int
shr1 m = intFromNat (natFromInt m `shiftRL` 1)
shr1 m = m `iShiftRL` 1
{-# INLINE shr1 #-}

-- | \(O(n)\). Build a set from an ascending list of elements.
Expand Down Expand Up @@ -1621,7 +1611,7 @@ link k1 t1 k2 t2 = linkWithMask (branchMask k1 k2) k1 t1 k2 t2
-- `linkWithMask` is useful when the `branchMask` has already been computed
linkWithMask :: Int -> Key -> IntSet -> Key -> IntSet -> IntSet
linkWithMask m k1 t1 k2 t2
| natFromInt k1 < natFromInt k2 = Bin p t1 t2
| i2w k1 < i2w k2 = Bin p t1 t2
| otherwise = Bin p t2 t1
where
p = Prefix (mask k1 m .|. m)
Expand Down Expand Up @@ -1685,18 +1675,18 @@ bitmapOf x = bitmapOfSuffix (suffixOf x)
The signatures of methods in question are placed after this comment.
----------------------------------------------------------------------}

lowestBitSet :: Nat -> Int
highestBitSet :: Nat -> Int
foldlBits :: Int -> (a -> Int -> a) -> a -> Nat -> a
foldl'Bits :: Int -> (a -> Int -> a) -> a -> Nat -> a
foldrBits :: Int -> (Int -> a -> a) -> a -> Nat -> a
foldr'Bits :: Int -> (Int -> a -> a) -> a -> Nat -> a
lowestBitSet :: Word -> Int
highestBitSet :: Word -> Int
foldlBits :: Int -> (a -> Int -> a) -> a -> Word -> a
foldl'Bits :: Int -> (a -> Int -> a) -> a -> Word -> a
foldrBits :: Int -> (Int -> a -> a) -> a -> Word -> a
foldr'Bits :: Int -> (Int -> a -> a) -> a -> Word -> a
#if MIN_VERSION_base(4,11,0)
foldMapBits :: Semigroup a => Int -> (Int -> a) -> Nat -> a
foldMapBits :: Semigroup a => Int -> (Int -> a) -> Word -> a
#else
foldMapBits :: Monoid a => Int -> (Int -> a) -> Nat -> a
foldMapBits :: Monoid a => Int -> (Int -> a) -> Word -> a
#endif
takeWhileAntitoneBits :: Int -> (Int -> Bool) -> Nat -> Nat
takeWhileAntitoneBits :: Int -> (Int -> Bool) -> Word -> Word

{-# INLINE lowestBitSet #-}
{-# INLINE highestBitSet #-}
Expand All @@ -1707,26 +1697,26 @@ takeWhileAntitoneBits :: Int -> (Int -> Bool) -> Nat -> Nat
{-# INLINE foldMapBits #-}
{-# INLINE takeWhileAntitoneBits #-}

lowestBitMask :: Nat -> Nat
#if defined(__GLASGOW_HASKELL__)

-- Return a word in which only the lowest set bit of the argument is kept
-- (the "lowest bit mask"). Computed as @x .&. negate x@: negating flips every
-- bit above the lowest set bit, so the conjunction keeps just that one bit.
-- The bit-folding loops in this module use it to peel bits off a bitmap one
-- at a time (@bm `xor` bitmask@).
lowestBitMask :: Word -> Word
lowestBitMask x = x .&. negate x
{-# INLINE lowestBitMask #-}

#if defined(__GLASGOW_HASKELL__)

lowestBitSet x = countTrailingZeros x

highestBitSet x = WORD_SIZE_IN_BITS - 1 - countLeadingZeros x

-- Reverse the order of bits in the Nat.
revNat :: Nat -> Nat
-- Reverse the order of bits in the Word.
revWord :: Word -> Word
#if WORD_SIZE_IN_BITS==32
revNat x1 = case ((x1 `shiftRL` 1) .&. 0x55555555) .|. ((x1 .&. 0x55555555) `shiftLL` 1) of
revWord x1 = case ((x1 `shiftRL` 1) .&. 0x55555555) .|. ((x1 .&. 0x55555555) `shiftLL` 1) of
x2 -> case ((x2 `shiftRL` 2) .&. 0x33333333) .|. ((x2 .&. 0x33333333) `shiftLL` 2) of
x3 -> case ((x3 `shiftRL` 4) .&. 0x0F0F0F0F) .|. ((x3 .&. 0x0F0F0F0F) `shiftLL` 4) of
x4 -> case ((x4 `shiftRL` 8) .&. 0x00FF00FF) .|. ((x4 .&. 0x00FF00FF) `shiftLL` 8) of
x5 -> ( x5 `shiftRL` 16 ) .|. ( x5 `shiftLL` 16);
#else
revNat x1 = case ((x1 `shiftRL` 1) .&. 0x5555555555555555) .|. ((x1 .&. 0x5555555555555555) `shiftLL` 1) of
revWord x1 = case ((x1 `shiftRL` 1) .&. 0x5555555555555555) .|. ((x1 .&. 0x5555555555555555) `shiftLL` 1) of
x2 -> case ((x2 `shiftRL` 2) .&. 0x3333333333333333) .|. ((x2 .&. 0x3333333333333333) `shiftLL` 2) of
x3 -> case ((x3 `shiftRL` 4) .&. 0x0F0F0F0F0F0F0F0F) .|. ((x3 .&. 0x0F0F0F0F0F0F0F0F) `shiftLL` 4) of
x4 -> case ((x4 `shiftRL` 8) .&. 0x00FF00FF00FF00FF) .|. ((x4 .&. 0x00FF00FF00FF00FF) `shiftLL` 8) of
Expand All @@ -1747,14 +1737,14 @@ foldl'Bits prefix f z bitmap = go bitmap z
where !bitmask = lowestBitMask bm
!bi = countTrailingZeros bitmask

foldrBits prefix f z bitmap = go (revNat bitmap) z
foldrBits prefix f z bitmap = go (revWord bitmap) z
where go 0 acc = acc
go bm acc = go (bm `xor` bitmask) ((f $! (prefix+(WORD_SIZE_IN_BITS-1)-bi)) acc)
where !bitmask = lowestBitMask bm
!bi = countTrailingZeros bitmask


foldr'Bits prefix f z bitmap = go (revNat bitmap) z
foldr'Bits prefix f z bitmap = go (revWord bitmap) z
where go 0 acc = acc
go bm !acc = go (bm `xor` bitmask) ((f $! (prefix+(WORD_SIZE_IN_BITS-1)-bi)) acc)
where !bitmask = lowestBitMask bm
Expand Down
14 changes: 7 additions & 7 deletions containers/src/Data/IntSet/Internal/IntTreeCommons.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ module Data.IntSet.Internal.IntTreeCommons
, treeTreeBranch
, mask
, branchMask
, i2w
) where

import Data.Bits (Bits(..))
import Utils.Containers.Internal.BitUtil (highestBitMask)
import Data.Bits (Bits(..), countLeadingZeros)
import Utils.Containers.Internal.BitUtil (wordSize)

#ifdef __GLASGOW_HASKELL__
import Language.Haskell.TH.Syntax (Lift)
Expand Down Expand Up @@ -149,18 +150,17 @@ mask i m = i .&. ((-m) `xor` m)
{-# INLINE mask #-}

-- | The first switching bit where the two prefixes disagree.
--
-- Precondition for defined behavior: p1 /= p2
branchMask :: Int -> Int -> Int
branchMask p1 p2 = w2i (highestBitMask (i2w (p1 `xor` p2)))
branchMask p1 p2 =
unsafeShiftL 1 (wordSize - 1 - countLeadingZeros (p1 `xor` p2))
{-# INLINE branchMask #-}

-- | Convert an 'Int' to a 'Word' using 'fromIntegral'.
--
-- Exported so that the IntMap and IntSet modules can compare keys as unsigned
-- words (for example @i2w k1 < i2w k2@ in @linkWithMask@) without each module
-- defining its own conversion helper.
i2w :: Int -> Word
i2w = fromIntegral
{-# INLINE i2w #-}

w2i :: Word -> Int
w2i = fromIntegral
{-# INLINE w2i #-}

{--------------------------------------------------------------------
Notes
--------------------------------------------------------------------}
Expand Down
33 changes: 21 additions & 12 deletions containers/src/Utils/Containers/Internal/BitUtil.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE CPP #-}
#if !defined(TESTING) && defined(__GLASGOW_HASKELL__)
{-# LANGUAGE Safe #-}
#ifdef __GLASGOW_HASKELL__
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE Trustworthy #-}
#endif

#include "containers.h"
Expand Down Expand Up @@ -28,26 +29,34 @@
-- closely.

module Utils.Containers.Internal.BitUtil
( highestBitMask
, shiftLL
( shiftLL
, shiftRL
, wordSize
, iShiftRL
) where

import Data.Bits (unsafeShiftL, unsafeShiftR
, countLeadingZeros, finiteBitSize
)

-- | Return a word where only the highest bit is set.
highestBitMask :: Word -> Word
highestBitMask w = shiftLL 1 (wordSize - 1 - countLeadingZeros w)
{-# INLINE highestBitMask #-}
import Data.Bits (unsafeShiftL, unsafeShiftR, finiteBitSize)
#ifdef __GLASGOW_HASKELL__
import GHC.Exts (Int(..), uncheckedIShiftRL#)
#endif

-- Right and left logical shifts.
--
-- 'shiftRL' is 'unsafeShiftR' and 'shiftLL' is 'unsafeShiftL' from "Data.Bits",
-- fixed to the 'Word' type. No range check on the shift amount is performed
-- here, which is why the precondition below must be respected by callers.
--
-- Precondition for defined behavior: 0 <= shift amount < wordSize
shiftRL, shiftLL :: Word -> Int -> Word
shiftRL = unsafeShiftR
shiftLL = unsafeShiftL

-- The number of bits in a 'Word', as reported by 'finiteBitSize'.
-- This is the exclusive upper bound on shift amounts accepted by the shift
-- helpers in this module.
{-# INLINE wordSize #-}
wordSize :: Int
wordSize = finiteBitSize (0 :: Word)

-- Right logical shift.
--
-- Shifts an 'Int' right as if it were an unsigned word: the non-GHC fallback
-- below spells this out by converting to 'Word', shifting with 'unsafeShiftR',
-- and converting back. On GHC the same operation is done directly on the
-- unboxed 'Int' with the @uncheckedIShiftRL#@ primop, avoiding the explicit
-- Int-Word round trip in the source.
--
-- The first argument is the value to shift; the second is the shift amount.
--
-- Precondition for defined behavior: 0 <= shift amount < wordSize
iShiftRL :: Int -> Int -> Int
#ifdef __GLASGOW_HASKELL__
-- Unbox both arguments and apply the primop; no range check is made on sh#.
iShiftRL (I# x#) (I# sh#) = I# (uncheckedIShiftRL# x# sh#)
#else
-- Portable version: perform the shift at type Word, then convert back to Int.
iShiftRL x sh = fromIntegral (unsafeShiftR (fromIntegral x :: Word) sh)
#endif

0 comments on commit 171b2e6

Please sign in to comment.