Skip to content

Commit

Permalink
FIXED: `Clash.Annotations.BitRepresentation.Deriving.deriveAnnotation…
Browse files Browse the repository at this point in the history
…` no longer has quadratic complexity in the size of the constructors and fields.
  • Loading branch information
rowanG077 committed May 16, 2022
1 parent 6ee18b0 commit a9870fa
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 47 deletions.
1 change: 1 addition & 0 deletions changelog/2022-05-12T18_44_42+02_00_cse_deriveAnnotation
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FIXED: `Clash.Annotations.BitRepresentation.Deriving.deriveAnnotation` no longer has quadratic complexity in the size of the constructors and fields.
151 changes: 104 additions & 47 deletions clash-prelude/src/Clash/Annotations/BitRepresentation/Deriving.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ import Data.Bits
(shiftL, shiftR, complement, (.&.), (.|.), zeroBits, popCount, bit, testBit,
Bits, setBit)
import Data.Data (Data)
import Data.Containers.ListUtils (nubOrd)
import Data.List
(mapAccumL, zipWith4, sortOn, partition)
import Data.Typeable (Typeable)
Expand Down Expand Up @@ -253,11 +254,26 @@ conName c = case c of
InfixC _ nm _ -> nm
_ -> error $ "No GADT support"

constrFieldSizes
:: Con
-> (Name, [Q Exp])
constrFieldSizes con = do
(conName con, map typeSize $ fieldTypes con)
mkInTTypedLet :: String -> Q Exp -> (Q Dec, Q Exp)
mkInTTypedLet nm qe = do
let nm' = mkName nm
let te = qe >>= \case
ListE [] -> appTypeE qe [t|Int|]
_ -> qe
(valD (varP nm') (normalB te) [], varE nm')

fieldSizeLets :: [[Type]] -> ([Q Dec], [[Q Exp]])
fieldSizeLets fieldtypess = (fieldSizeDecls, fieldSizessExps)
where
nums = map show [(0 :: Int)..]
uqFieldTypes = nubOrd (concat fieldtypess)
uqFieldSizes = map typeSize uqFieldTypes
(fieldSizeDecls, szVars) = unzip $ zipWith
(\i sz -> mkInTTypedLet ("_f" ++ i) sz)
nums
uqFieldSizes
tySizeMap = Map.fromList (zip uqFieldTypes szVars)
fieldSizessExps = map (map (tySizeMap Map.!)) fieldtypess

complementInteger :: Int -> Integer -> Integer
complementInteger 0 _i = 0
Expand Down Expand Up @@ -301,67 +317,90 @@ buildConstrRepr dataSize constrName fieldAnns constrMask constrValue = [|
countConstructor :: [Int] -> [(BitMask, Value)]
countConstructor ns = zip (repeat mask) (map toInteger ns)
where
maskSize = bitsNeeded $ toInteger $ maximum ns + 1
maskSize = bitsNeeded $ toInteger $ maximum @[] @Int ns + 1
mask = 2^maskSize - 1

oneHotConstructor :: [Int] -> [(BitMask, Value)]
oneHotConstructor ns = zip values values
where
values = [shiftL 1 n | n <- ns]

overlapFieldAnnsL :: [[Q Exp]] -> [[Q Exp]]
overlapFieldAnnsL fieldSizess = map go fieldSizess
overlapFieldAnnsL :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
overlapFieldAnnsL fieldSizess = ([maxDecl], resExp)
where
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map sum $fieldSizess' |]
go fieldSizes =
(maxDecl, maxExp) = mkInTTypedLet "_maxf" maxConstrSize
resExp = map go fieldSizess
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map (sum @[] @Int) $fieldSizess' |]
maxConstrSize = [| maximum @[] @Int $constructorSizes - 1 |]
go fieldsizes =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| maximum $constructorSizes - 1 |]
fieldSizes
maxExp
fieldsizes

overlapFieldAnnsR :: [[Q Exp]] -> [[Q Exp]]
overlapFieldAnnsR fieldSizess = map go fieldSizess
overlapFieldAnnsR :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
overlapFieldAnnsR fieldSizess = (sumFieldDecl, resExp)
where
fieldSizess' = listE $ map listE fieldSizess
constructorSizes = [| map sum $fieldSizess' |]
go fieldSizes =
resExp = zipWith go fieldSizess sumFieldExp

nums = map show [(0 :: Int) ..]

(sumFieldDecl, sumFieldExp)
= unzip $ zipWith
(\fs i -> mkInTTypedLet ("_sumf" ++ i) [|sum @[] @Int $(listE fs)|])
fieldSizess
nums

go fieldSizes sumFieldsSize =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| maximum $constructorSizes - (maximum $constructorSizes - sum $(listE fieldSizes)) - 1 |]
[| $sumFieldsSize - 1 |]
fieldSizes

wideFieldAnns :: [[Q Exp]] -> [[Q Exp]]
wideFieldAnns fieldSizess = zipWith id (map go constructorOffsets) fieldSizess
wideFieldAnns :: [[Q Exp]] -> ([Q Dec], [[Q Exp]])
wideFieldAnns fieldSizess = (decs, resExp)
where
constructorSizes =
map (AppE (VarE 'sum) <$>) (map listE fieldSizess)

constructorOffsets :: [Q Exp]
constructorOffsets =
init $
scanl
(\offset size -> [| $offset + $size |])
[| 0 |]
constructorSizes

dataSize = [| sum $(listE constructorSizes) |]

decs = (dataSizeDec:constrSizeDecs) ++ constrOffsetDecs
resExp = zipWith id (map go constrOffsetsExps) fieldSizess
nums = map show [(0 :: Int) ..]

constrSizeExps :: [Q Exp]
(constrSizeDecs, constrSizeExps)
= unzip $ zipWith
(\fs i -> mkInTTypedLet ("_sumf" ++ i) [|sum @[] @Int $(listE fs)|])
fieldSizess
nums

constrOffsetsExps :: [Q Exp]
(last -> constrOffsetDecs, constrOffsetsExps) =
unzip $ init $ scanl
(\(ds, offset) (size, i) ->
let e = [| $offset + $size |]
(d, ve) = mkInTTypedLet ("_constroffset" ++ i) e
in (d:ds, ve)
)
([], [| 0 |])
(zip constrSizeExps nums)

dataSizeExp :: Q Exp
(dataSizeDec, dataSizeExp)
= mkInTTypedLet "_widedatasize" [| sum @[] @Int $(listE constrSizeExps) - 1 |]
go :: Q Exp -> [Q Exp] -> [Q Exp]
go offset fieldSizes =
snd $
mapAccumL
(\start size -> ([| $start - $size |], [| bitmask $start $size |]))
[| $dataSize - 1 - $offset |]
[| $dataSizeExp - $offset |]
fieldSizes

-- | Derive DataRepr' for a specific type.
deriveDataRepr
:: ([Int] -> [(BitMask, Value)])
-- ^ Constructor derivator
-> ([[Q Exp]] -> [[Q Exp]])
-> ([[Q Exp]] -> ([Q Dec], [[Q Exp]]) )
-- ^ Field derivator
-> Derivator
deriveDataRepr constrDerivator fieldsDerivator typ = do
Expand All @@ -370,34 +409,52 @@ deriveDataRepr constrDerivator fieldsDerivator typ = do
(TyConI (DataD [] _constrName vars _kind dConstructors _clauses)) ->
let varMap = Map.fromList $ zip (map tyVarBndrName vars) typeArgs in
let resolvedConstructors = map (resolveCon varMap) dConstructors in do
let nums = map show [(0 :: Int)..]
let fieldtypess = map fieldTypes resolvedConstructors

let (fieldSzDecs, fieldSizess) = fieldSizeLets fieldtypess

-- Get sizes and names of all constructors
let
(constrNames, fieldSizess) =
unzip $ map constrFieldSizes resolvedConstructors
let constrNames = map conName resolvedConstructors

let
(constrMasks, constrValues) =
unzip $ constrDerivator [0..length dConstructors - 1]

let constrSize = 1 + (msb $ maximum constrMasks)
let fieldAnns = fieldsDerivator fieldSizess
let fieldAnnsFlat = listE $ concat fieldAnns
let constrSize = 1 + (msb $ maximum @[] @Integer constrMasks)
let (fieldDecs, fieldAnns) = fieldsDerivator fieldSizess

-- extract field annotations into declarations
let mkAnnDecl i j an = mkInTTypedLet ("_fa" ++ i ++ "_" ++ j) an
let
fieldAnnTup =
zipWith (\i -> zipWith (mkAnnDecl i) nums) nums fieldAnns

let
(fieldAnnDecs, fieldAnnVars) =
(concat $ map (map fst) fieldAnnTup, map (map snd) fieldAnnTup)

let fieldAnnsFlat = listE $ concat fieldAnnVars

let dataSize | null $ concat fieldAnns = [| 0 |]
| otherwise = [| 1 + (msb $ maximum $ $fieldAnnsFlat) |]
| otherwise = [| 1 + (msb $ maximum @[] @Integer $ $fieldAnnsFlat) |]

-- Extract data size into a declaration
let (dataSizeDec, dataSizeExp) = mkInTTypedLet "_datasize" dataSize

let decls = (dataSizeDec:fieldSzDecs) ++ fieldDecs ++ fieldAnnDecs

-- Determine at which bits various fields start
let constrReprs = zipWith4
(buildConstrRepr dataSize)
(buildConstrRepr dataSizeExp)
constrNames
fieldAnns
fieldAnnVars
constrMasks
constrValues

[| DataReprAnn
letE decls [| DataReprAnn
$(liftQ $ return typ)
($dataSize + constrSize)
($dataSizeExp + constrSize)
$(listE constrReprs) |]
_ ->
fail $ "Could not derive dataRepr for: " ++ show info
Expand Down

0 comments on commit a9870fa

Please sign in to comment.