Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various deriving bitrepr fixes #2209

Merged
merged 5 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/2022-05-12T18_44_42+02_00_cse_deriveAnnotation
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: `Clash.Annotations.BitRepresentation.Deriving.deriveAnnotation` no longer has quadratic complexity in the size of the constructors and fields.
1 change: 1 addition & 0 deletions changelog/2022-05-16T10_06_57+02_00_bitrepr_symbols
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: Added support for symbols in types while deriving bit representation.
1 change: 1 addition & 0 deletions changelog/2022-05-16T10_13_47+02_00_bitrepr_promoted
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: Added support for promoted data types while deriving bit representations.
1 change: 1 addition & 0 deletions changelog/2022-05-16T10_15_05+02_00_bitrepr_resolve_types
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: Fully resolve type synonyms when deriving bit representations.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-|
Copyright : (C) 2018, Google Inc.
2022, LUMI GUIDE FIETSDETECTIE B.V.
License : BSD2 (see the file LICENSE)
Maintainer : Christiaan Baaij <christiaan.baaij@gmail.com>

Expand All @@ -15,12 +16,13 @@ module Clash.Annotations.BitRepresentation.ClashLib
) where

import Clash.Annotations.BitRepresentation.Internal
(Type'(AppTy',ConstTy',LitTy'))
(Type'(..))
import qualified Clash.Annotations.BitRepresentation.Util as BitRepresentation
import qualified Clash.Core.Type as C
import Clash.Core.Name (nameOcc)
import qualified Clash.Netlist.Types as Netlist
import Clash.Util (curLoc)
import qualified Data.Text as T (pack)

-- Convert Core type to BitRepresentation type
coreToType'
Expand All @@ -37,6 +39,8 @@ coreToType' (C.ConstTy (C.TyCon name)) =
return $ ConstTy' (nameOcc name)
coreToType' (C.LitTy (C.NumTy n)) =
return $ LitTy' n
coreToType' (C.LitTy (C.SymTy lit)) =
return $ SymLitTy' (T.pack lit)
coreToType' e =
Left $ $(curLoc) ++ "Unexpected type: " ++ show e

Expand Down
192 changes: 127 additions & 65 deletions clash-prelude/src/Clash/Annotations/BitRepresentation/Deriving.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,35 @@ import Clash.Annotations.BitRepresentation.Util
import qualified Clash.Annotations.BitRepresentation.Util
as Util

import Clash.Annotations.Primitive (hasBlackBox)
import Clash.Annotations.Primitive (hasBlackBox)
import Clash.Class.BitPack
(BitPack, BitSize, pack, packXWith, unpack)
import Clash.Class.Resize (resize)
import Language.Haskell.TH.Compat (mkTySynInstD)
import Clash.Sized.BitVector (BitVector, low, (++#))
import Clash.Class.Resize (resize)
import Language.Haskell.TH.Compat (mkTySynInstD)
import Clash.Sized.BitVector (BitVector, low, (++#))
import Clash.Sized.Internal.BitVector (undefined#)
import Control.DeepSeq (NFData)
import Control.Monad (forM)
import Control.Applicative (liftA3)
import Control.DeepSeq (NFData)
import Control.Monad (forM)
import Data.Bits
(shiftL, shiftR, complement, (.&.), (.|.), zeroBits, popCount, bit, testBit,
Bits, setBit)
import Data.Data (Data)
import Data.Data (Data)
import Data.Containers.ListUtils (nubOrd)
import Data.List
(mapAccumL, zipWith4, sortOn, partition)
import Data.Typeable (Typeable)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import qualified Data.Set as Set
import Data.Proxy (Proxy(..))
import GHC.Exts (Int(I#))
import GHC.Generics (Generic)
import GHC.Integer.Logarithms (integerLog2#)
import GHC.TypeLits (natVal)
import Data.Typeable (Typeable)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import qualified Data.Set as Set
import Data.Proxy (Proxy(..))
import GHC.Exts (Int(I#))
import GHC.Generics (Generic)
import GHC.Integer.Logarithms (integerLog2#)
import GHC.TypeLits (natVal)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.Datatype (resolveTypeSynonyms)

-- | Used to track constructor bits in packed derivation
data BitMaskOrigin
Expand Down Expand Up @@ -253,11 +256,23 @@ conName c = case c of
InfixC _ nm _ -> nm
_ -> error $ "No GADT support"

constrFieldSizes
:: Con
-> (Name, [Q Exp])
constrFieldSizes con = do
(conName con, map typeSize $ fieldTypes con)
mkLet :: String -> Q Exp -> (Q Dec, Q Exp)
mkLet nm qe = do
let nm' = mkName nm
(valD (varP nm') (normalB qe) [], varE nm')

fieldSizeLets :: [[Type]] -> ([Q Dec], [[Q Exp]])
fieldSizeLets fieldtypess = (fieldSizeDecls, fieldSizessExps)
where
nums = map show [(0 :: Int)..]
uqFieldTypes = nubOrd (concat fieldtypess)
uqFieldSizes = map typeSize uqFieldTypes
(fieldSizeDecls, szVars) = unzip $ zipWith
(\i sz -> mkLet ("_f" ++ i) sz)
nums
uqFieldSizes
tySizeMap = Map.fromList (zip uqFieldTypes szVars)
fieldSizessExps = map (map (tySizeMap Map.!)) fieldtypess

complementInteger :: Int -> Integer -> Integer
complementInteger 0 _i = 0
Expand Down Expand Up @@ -309,59 +324,82 @@ oneHotConstructor ns = zip values values
where
values = [shiftL 1 n | n <- ns]

overlapFieldAnnsL :: [[Q Exp]] -> [[Q Exp]]
overlapFieldAnnsL fieldSizess = map go fieldSizess
overlapFieldAnnsL :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
overlapFieldAnnsL fieldSizess = ([maxDecl], resExp)
where
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map sum $fieldSizess' |]
go fieldSizes =
(maxDecl, maxExp) = mkLet "_maxf" maxConstrSize
resExp = map go fieldSizess
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map (sum @[] @Int) $fieldSizess' |]
maxConstrSize = [| maximum $constructorSizes - 1 |]
go fieldsizes =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| maximum $constructorSizes - 1 |]
fieldSizes
maxExp
fieldsizes

overlapFieldAnnsR :: [[Q Exp]] -> [[Q Exp]]
overlapFieldAnnsR fieldSizess = map go fieldSizess
overlapFieldAnnsR :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
overlapFieldAnnsR fieldSizess = (sumFieldDecl, resExp)
where
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map sum $fieldSizess' |]
go fieldSizes =
resExp = zipWith go fieldSizess sumFieldExp

nums = map show [(0 :: Int) ..]

(sumFieldDecl, sumFieldExp)
= unzip $ zipWith
(\fs i -> mkLet ("_sumf" ++ i) [|sum @[] @Int $(listE fs)|])
fieldSizess
nums

go fieldSizes sumFieldsSize =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| maximum $constructorSizes - (maximum $constructorSizes - sum $(listE fieldSizes)) - 1 |]
[| $sumFieldsSize - 1 |]
fieldSizes

wideFieldAnns :: [[Q Exp]] -> [[Q Exp]]
wideFieldAnns fieldSizess = zipWith id (map go constructorOffsets) fieldSizess
wideFieldAnns :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
wideFieldAnns fieldSizess = (decs, resExp)
where
constructorSizes =
map (AppE (VarE 'sum) <$>) (map listE fieldSizess)

constructorOffsets :: [Q Exp]
constructorOffsets =
init $
scanl
(\offset size -> [| $offset + $size |])
[| 0 |]
constructorSizes

dataSize = [| sum $(listE constructorSizes) |]

decs = (dataSizeDec:constrSizeDecs) ++ constrOffsetDecs
resExp = zipWith id (map go constrOffsetsExps) fieldSizess
nums = map show [(0 :: Int) ..]

constrSizeExps :: [Q Exp]
(constrSizeDecs, constrSizeExps)
= unzip $ zipWith
(\fs i -> mkLet ("_sumf" ++ i) [|sum @[] @Int $(listE fs)|])
fieldSizess
nums

constrOffsetsExps :: [Q Exp]
(last -> constrOffsetDecs, constrOffsetsExps) =
unzip $ init $ scanl
(\(ds, offset) (size, i) ->
let e = [| $offset + $size |]
(d, ve) = mkLet ("_constroffset" ++ i) e
in (d:ds, ve)
)
([], [| 0 |])
(zip constrSizeExps nums)

dataSizeExp :: Q Exp
(dataSizeDec, dataSizeExp)
= mkLet "_widedatasize" [| sum @[] @Int $(listE constrSizeExps) - 1 |]
go :: Q Exp -> [Q Exp] -> [Q Exp]
go offset fieldSizes =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| $dataSize - 1 - $offset |]
[| $dataSizeExp - $offset |]
fieldSizes

-- | Derive DataRepr' for a specific type.
deriveDataRepr
:: ([Int] -> [(BitMask, Value)])
-- ^ Constructor derivator
-> ([[Q Exp]] -> [[Q Exp]])
-> ([[Q Exp]] -> ([Q Dec], [[Q Exp]]) )
-- ^ Field derivator
-> Derivator
deriveDataRepr constrDerivator fieldsDerivator typ = do
Expand All @@ -370,34 +408,54 @@ deriveDataRepr constrDerivator fieldsDerivator typ = do
(TyConI (DataD [] _constrName vars _kind dConstructors _clauses)) ->
let varMap = Map.fromList $ zip (map tyVarBndrName vars) typeArgs in
let resolvedConstructors = map (resolveCon varMap) dConstructors in do
let nums = map show [(0 :: Int)..]
let fieldtypess = map fieldTypes resolvedConstructors

let (fieldSzDecs, fieldSizess) = fieldSizeLets fieldtypess

-- Get sizes and names of all constructors
let
(constrNames, fieldSizess) =
unzip $ map constrFieldSizes resolvedConstructors
let constrNames = map conName resolvedConstructors

let
(constrMasks, constrValues) =
unzip $ constrDerivator [0..length dConstructors - 1]

let constrSize = 1 + (msb $ maximum constrMasks)
let fieldAnns = fieldsDerivator fieldSizess
let fieldAnnsFlat = listE $ concat fieldAnns
let constrSize = 1 + (msb $ maximum @[] @Integer constrMasks)
let (fieldDecs, fieldAnns) = fieldsDerivator fieldSizess

-- extract field annotations into declarations
let mkAnnDecl i j an = mkLet ("_fa" ++ i ++ "_" ++ j) an
let
fieldAnnTup =
zipWith (\i -> zipWith (mkAnnDecl i) nums) nums fieldAnns

let
(fieldAnnDecs, fieldAnnVars) =
(concat $ map (map fst) fieldAnnTup, map (map snd) fieldAnnTup)

let fieldAnnsFlat = listE $ concat fieldAnnVars

let dataSize | null $ concat fieldAnns = [| 0 |]
| otherwise = [| 1 + (msb $ maximum $ $fieldAnnsFlat) |]
| otherwise = [| 1 + (msb $ maximum @[] @Integer $ $fieldAnnsFlat) |]

-- Extract data size into a declaration
let (dataSizeDec, dataSizeExp) = mkLet "_datasize" dataSize

let decls = (dataSizeDec:fieldSzDecs) ++ fieldDecs ++ fieldAnnDecs

-- Determine at which bits various fields start
let constrReprs = zipWith4
(buildConstrRepr dataSize)
(buildConstrRepr dataSizeExp)
constrNames
fieldAnns
fieldAnnVars
constrMasks
constrValues

[| DataReprAnn
$(liftQ $ return typ)
($dataSize + constrSize)
resolvedType <- resolveTypeSynonyms typ

letE decls [| DataReprAnn
$(liftQ $ return resolvedType)
($dataSizeExp + constrSize)
$(listE constrReprs) |]
_ ->
fail $ "Could not derive dataRepr for: " ++ show info
Expand Down Expand Up @@ -727,8 +785,11 @@ derivePackedAnnotation = deriveAnnotation packedDerivator
collectDataReprs :: Q [DataReprAnn]
collectDataReprs = do
thisMod <- thisModule
go [thisMod] Set.empty []
unresolved <- go [thisMod] Set.empty []
mapM resolveTyps unresolved
where
resolveTyps (DataReprAnn t s c)
= liftA3 DataReprAnn (resolveTypeSynonyms t) (pure s) (pure c)
go [] _visited acc = return acc
go (x:xs) visited acc
| x `Set.member` visited = go xs visited acc
Expand Down Expand Up @@ -929,8 +990,9 @@ deriveBitPack :: Q Type -> Q [Dec]
deriveBitPack typQ = do
anns <- collectDataReprs
typ <- typQ
rTyp <- resolveTypeSynonyms typ

ann <- case filter (\(DataReprAnn t _ _) -> t == typ) anns of
ann <- case filter (\(DataReprAnn t _ _) -> t == rTyp) anns of
[a] -> return a
[] -> fail "No custom bit annotation found."
_ -> fail "Overlapping bit annotations found."
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-|
Copyright : (C) 2018, Google Inc.
2022, LUMI GUIDE FIETSDETECTIE B.V.
License : BSD2 (see the file LICENSE)
Maintainer : Christiaan Baaij <christiaan.baaij@gmail.com>
-}
Expand Down Expand Up @@ -44,6 +45,8 @@ data Type'
-- ^ Qualified name of type
| LitTy' Integer
-- ^ Numeral literal (used in BitVector 10, for example)
| SymLitTy' Text.Text
-- ^ Symbol literal (used in for example (Signal "System" Int))
deriving (Generic, NFData, Eq, Typeable, Hashable, Ord, Show)

-- | Internal version of DataRepr
Expand Down Expand Up @@ -90,8 +93,10 @@ thTypeToType' :: TH.Type -> Type'
thTypeToType' ty = go ty
where
go (TH.ConT name') = ConstTy' (thToText name')
go (TH.PromotedT name') = ConstTy' (thToText name')
go (TH.AppT ty1 ty2) = AppTy' (go ty1) (go ty2)
go (TH.LitT (TH.NumTyLit n)) = LitTy' n
go (TH.LitT (TH.StrTyLit lit)) = SymLitTy' (Text.pack lit)
go _ = error $ "Unsupported type: " ++ show ty

-- | Convenience type for index built by buildCustomReprs
Expand Down
25 changes: 20 additions & 5 deletions clash-prelude/tests/Clash/Tests/DerivingDataRepr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ import Data.Maybe (Maybe(..))
---------------------------------------------------------
------------ DERIVING SIMPLE REPRESENTATIONS ------------
---------------------------------------------------------
oneHotOverlapRepr :: DataReprAnn
oneHotOverlapRepr = $( (simpleDerivator OneHot OverlapL) =<< [t| Train |] )
oneHotOverlapLRepr :: DataReprAnn
oneHotOverlapLRepr = $( (simpleDerivator OneHot OverlapL) =<< [t| Train |] )

oneHotOverlapRepr' :: DataReprAnn
oneHotOverlapRepr' =
oneHotOverlapLRepr' :: DataReprAnn
oneHotOverlapLRepr' =
DataReprAnn
$(liftQ [t| Train |])
8
Expand All @@ -30,6 +30,20 @@ oneHotOverlapRepr' =
, ConstrRepr 'Toy 128 128 []
]

oneHotOverlapRRepr :: DataReprAnn
oneHotOverlapRRepr = $( (simpleDerivator OneHot OverlapR) =<< [t| Train |] )

oneHotOverlapRRepr' :: DataReprAnn
oneHotOverlapRRepr' =
DataReprAnn
$(liftQ [t| Train |])
8
[ ConstrRepr 'Passenger 16 16 [0b0011]
, ConstrRepr 'Freight 32 32 [0b1100, 0b0011]
, ConstrRepr 'Maintenance 64 64 []
, ConstrRepr 'Toy 128 128 []
]

oneHotOverlapReprRec :: DataReprAnn
oneHotOverlapReprRec = $( (simpleDerivator OneHot OverlapL) =<< [t| Headphones |] )

Expand Down Expand Up @@ -132,7 +146,8 @@ packedMaybeRGB' =
-- MAIN
tests :: TestTree
tests = testGroup "DerivingDataRepr"
[ testCase "OneHotOverlap" $ oneHotOverlapRepr @?= oneHotOverlapRepr'
[ testCase "OneHotOverlapL" $ oneHotOverlapLRepr @?= oneHotOverlapLRepr'
, testCase "OneHotOverlapR" $ oneHotOverlapRRepr @?= oneHotOverlapRRepr'
, testCase "OneHotOverlapRec" $ oneHotOverlapReprRec @?= oneHotOverlapReprRec'
, testCase "OneHotOverlapInfix" $ oneHotOverlapReprInfix @?= oneHotOverlapReprInfix'
, testCase "OneHotWide" $ oneHotWideRepr @?= oneHotWideRepr'
Expand Down