Skip to content

Commit

Permalink
FIXED: `Clash.Annotations.BitRepresentation.Deriving.deriveAnnotation…
Browse files Browse the repository at this point in the history
…` no longer has quadratic complexity in the size of the constructors and fields.
  • Loading branch information
rowanG077 committed May 12, 2022
1 parent 6ee18b0 commit 28f62b8
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 42 deletions.
1 change: 1 addition & 0 deletions changelog/2022-05-12T18_44_42+02_00_cse_deriveAnnotation
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: `Clash.Annotations.BitRepresentation.Deriving.deriveAnnotation` no longer has quadratic complexity in the size of the constructors and fields.
154 changes: 112 additions & 42 deletions clash-prelude/src/Clash/Annotations/BitRepresentation/Deriving.hs
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,21 @@ conName c = case c of
InfixC _ nm _ -> nm
_ -> error $ "No GADT support"

constrFieldSizes
:: Con
-> (Name, [Q Exp])
constrFieldSizes con = do
(conName con, map typeSize $ fieldTypes con)
fieldSizeLets :: [[Type]] -> ([Q Dec], [[Q Exp]])
fieldSizeLets fieldtypess = (fieldSizeDecls, fieldSizessVars)
where
nums = map show [(0 :: Int)..]
uqFieldTypes = Set.toList
$ foldl (\s ts -> (Set.fromList ts) `Set.union` s )
Set.empty fieldtypess
uqFieldSizes = map typeSize uqFieldTypes
fieldIds = map (mkName . ("f" ++)) nums
fieldSizeDecls = zipWith
(\nm sz -> valD (varP nm) (normalB sz) [])
fieldIds
uqFieldSizes
tySizeMap = Map.fromList (zip uqFieldTypes (map varE fieldIds))
fieldSizessVars = map (map (tySizeMap Map.!)) fieldtypess

complementInteger :: Int -> Integer -> Integer
complementInteger 0 _i = 0
Expand Down Expand Up @@ -309,59 +319,97 @@ oneHotConstructor ns = zip values values
where
values = [shiftL 1 n | n <- ns]

overlapFieldAnnsL :: [[Q Exp]] -> [[Q Exp]]
overlapFieldAnnsL fieldSizess = map go fieldSizess
overlapFieldAnnsL :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
overlapFieldAnnsL fieldSizess = (decls, resExp)
where
fieldSizess' = listE $ map listE fieldSizess
maxFName = mkName "maxf"
decls = [valD (varP maxFName) (normalB maxConstrSize) []]
resExp = map go fieldSizess
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map sum $fieldSizess' |]
go fieldSizes =
maxConstrSize = [| maximum $constructorSizes - 1 |]
go fieldsizes =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| maximum $constructorSizes - 1 |]
fieldSizes
(varE maxFName)
fieldsizes

overlapFieldAnnsR :: [[Q Exp]] -> [[Q Exp]]
overlapFieldAnnsR fieldSizess = map go fieldSizess
overlapFieldAnnsR :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
overlapFieldAnnsR fieldSizess = (decls, resExp)
where
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map sum $fieldSizess' |]
go fieldSizes =
decls = sumFieldDecl
resExp = zipWith go fieldSizess sumFieldExp

nums = map show [(0 :: Int) ..]

(sumFieldExp, sumFieldDecl)
= unzip $ zipWith
(\fs i ->
let e = [|sum $(listE fs)|]
nm = mkName ("sumf_" ++ i)
in (varE nm, valD (varP nm) (normalB e) [])
)
fieldSizess
nums

go fieldSizes sumFieldsSize =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| maximum $constructorSizes - (maximum $constructorSizes - sum $(listE fieldSizes)) - 1 |]
[| $sumFieldsSize - 1 |]
fieldSizes

wideFieldAnns :: [[Q Exp]] -> [[Q Exp]]
wideFieldAnns fieldSizess = zipWith id (map go constructorOffsets) fieldSizess
wideFieldAnns :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
wideFieldAnns fieldSizess = (decls, resExp)
where
constructorSizes =
map (AppE (VarE 'sum) <$>) (map listE fieldSizess)

constructorOffsets :: [Q Exp]
constructorOffsets =
init $
scanl
(\offset size -> [| $offset + $size |])
[| 0 |]
constructorSizes

dataSize = [| sum $(listE constructorSizes) |]
decls = (dataSizeDecl:constrSizeDec) ++ constrOffsetDecls
resExp = zipWith id (map go constrOffsets) fieldSizess
nums = map show [(0 :: Int) ..]

constrSizeExps :: [Q Exp]
(constrSizeExps, constrSizeDec)
= unzip $ zipWith
(\fs i ->
let e = [|sum $(listE fs)|]
nm = mkName ("sumf_" ++ i)
in (varE nm, valD (varP nm) (normalB e) [])
)
fieldSizess
nums

constrOffsets :: [Q Exp]
(constrOffsets, last -> constrOffsetDecls) =
unzip $ init $ scanl
(\(offset, ds) (size, i) ->
let nm = mkName ("constroffset_" ++ i)
e = [| $offset + $size |]

d = valD (varP nm) (normalB e) []
in (varE nm, d:ds)
)
([| 0 |], [])
(zip constrSizeExps nums)

dataSizeExp :: Q Exp
(dataSizeExp ,dataSizeDecl) =
let nm = mkName "widedatasize"
e = [| sum $(listE constrSizeExps) - 1 |]
in (varE nm, valD (varP nm) (normalB e) [])

go :: Q Exp -> [Q Exp] -> [Q Exp]
go offset fieldSizes =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| $dataSize - 1 - $offset |]
[| $dataSizeExp - $offset |]
fieldSizes

-- | Derive DataRepr' for a specific type.
deriveDataRepr
:: ([Int] -> [(BitMask, Value)])
-- ^ Constructor derivator
-> ([[Q Exp]] -> [[Q Exp]])
-> ([[Q Exp]] -> ([Q Dec], [[Q Exp]]) )
-- ^ Field derivator
-> Derivator
deriveDataRepr constrDerivator fieldsDerivator typ = do
Expand All @@ -370,34 +418,56 @@ deriveDataRepr constrDerivator fieldsDerivator typ = do
(TyConI (DataD [] _constrName vars _kind dConstructors _clauses)) ->
let varMap = Map.fromList $ zip (map tyVarBndrName vars) typeArgs in
let resolvedConstructors = map (resolveCon varMap) dConstructors in do
let nums = map show [(0 :: Int)..]
let fieldtypess = map fieldTypes resolvedConstructors

let (fieldSzDecls, fieldSizess) = fieldSizeLets fieldtypess

-- Get sizes and names of all constructors
let
(constrNames, fieldSizess) =
unzip $ map constrFieldSizes resolvedConstructors
let constrNames = map conName resolvedConstructors

let
(constrMasks, constrValues) =
unzip $ constrDerivator [0..length dConstructors - 1]

let constrSize = 1 + (msb $ maximum constrMasks)
let fieldAnns = fieldsDerivator fieldSizess
let fieldAnnsFlat = listE $ concat fieldAnns
let (fieldDecls, fieldAnns) = fieldsDerivator fieldSizess

-- extract field annotations into declarations
let mkAnnDecl i j an = let nm = mkName ("fa_" ++ i ++ "_" ++ j)
in (varE nm, valD (varP nm) (normalB an) [])
let
fieldAnnTup =
zipWith (\i -> zipWith (mkAnnDecl i) nums) nums fieldAnns

let
(fieldAnnVars, fieldAnnDecls) =
(map (map fst) fieldAnnTup, concat $ map (map snd) fieldAnnTup)

let fieldAnnsFlat = listE $ concat fieldAnnVars

let dataSize | null $ concat fieldAnns = [| 0 |]
| otherwise = [| 1 + (msb $ maximum $ $fieldAnnsFlat) |]

-- Extract data size into a declaration
let
(dataSizeExp, dataSizeDec) =
let nm = mkName "datasize"
in (varE nm, valD (varP nm) (normalB dataSize) [])

let decls = (dataSizeDec:fieldSzDecls) ++ fieldDecls ++ fieldAnnDecls

-- Determine at which bits various fields start
let constrReprs = zipWith4
(buildConstrRepr dataSize)
(buildConstrRepr dataSizeExp)
constrNames
fieldAnns
fieldAnnVars
constrMasks
constrValues

[| DataReprAnn
letE decls [| DataReprAnn
$(liftQ $ return typ)
($dataSize + constrSize)
($dataSizeExp + constrSize)
$(listE constrReprs) |]
_ ->
fail $ "Could not derive dataRepr for: " ++ show info
Expand Down

0 comments on commit 28f62b8

Please sign in to comment.