Skip to content

Commit

Permalink
Allow arbitrary expressions in Size.
Browse files Browse the repository at this point in the history
Currently just a rewrite as NamedSize -> SizeExpr (Var ...) and ConstSize -> SizeExpr (Literal ...). There is no type check, and other expressions are not supported. Step for #1659.
  • Loading branch information
catvayor committed Mar 7, 2023
1 parent a7f169b commit 9476827
Show file tree
Hide file tree
Showing 21 changed files with 226 additions and 168 deletions.
5 changes: 3 additions & 2 deletions src/Futhark/Doc/Generator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -705,8 +705,9 @@ relativise dest src =
concat (replicate (length (splitPath src) - 1) "../") ++ dest

dimDeclHtml :: Size -> DocM Html
dimDeclHtml (NamedSize v) = brackets <$> qualNameHtml v
dimDeclHtml (ConstSize n) = pure $ brackets $ toHtml (show n)
dimDeclHtml (SizeExpr (Var v _ _)) = brackets <$> qualNameHtml v
dimDeclHtml (SizeExpr (Literal (SignedValue (Int64Value n)) _)) = pure $ brackets $ toHtml (show n)
dimDeclHtml (SizeExpr _) = error "Arbitrary Expression not supported yet"
dimDeclHtml AnySize {} = pure $ brackets mempty

dimExpHtml :: SizeExp Info VName -> DocM Html
Expand Down
44 changes: 22 additions & 22 deletions src/Futhark/Internalise/Defunctionalise.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ replaceTypeSizes ::
TypeBase Size als
replaceTypeSizes substs = first onDim
where
onDim (NamedSize v) =
onDim (SizeExpr (Var v typ loc)) =
case M.lookup (qualLeaf v) substs of
Just (SubstNamed v') -> NamedSize v'
Just (SubstConst d) -> ConstSize d
Nothing -> NamedSize v
Just (SubstNamed v') -> SizeExpr (Var v' typ loc)
Just (SubstConst d) -> SizeExpr (Literal (SignedValue $ Int64Value d) loc)
Nothing -> SizeExpr (Var v typ loc)
onDim d = d

replaceStaticValSizes ::
Expand Down Expand Up @@ -265,15 +265,15 @@ arraySizes (Scalar (Sum cs)) = foldMap (foldMap arraySizes) cs
arraySizes (Scalar (TypeVar _ _ _ targs)) =
mconcat $ map f targs
where
f (TypeArgDim (NamedSize d) _) = S.singleton $ qualLeaf d
f (TypeArgDim (SizeExpr (Var d _ _)) _) = S.singleton $ qualLeaf d
f TypeArgDim {} = mempty
f (TypeArgType t _) = arraySizes t
arraySizes (Scalar Prim {}) = mempty
arraySizes (Array _ _ shape t) =
arraySizes (Scalar t) <> foldMap dimName (shapeDims shape)
where
dimName :: Size -> S.Set VName
dimName (NamedSize qn) = S.singleton $ qualLeaf qn
dimName (SizeExpr (Var qn _ _)) = S.singleton $ qualLeaf qn
dimName _ = mempty

patternArraySizes :: Pat -> S.Set VName
Expand All @@ -291,14 +291,14 @@ dimMapping ::
M.Map VName SizeSubst
dimMapping t1 t2 = execState (matchDims f t1 t2) mempty
where
f bound d1 (NamedSize d2)
f bound d1 (SizeExpr (Var d2 _ _))
| qualLeaf d2 `elem` bound = pure d1
f _ (NamedSize d1) (NamedSize d2) = do
f _ (SizeExpr (Var d1 typ loc)) (SizeExpr (Var d2 _ _)) = do
modify $ M.insert (qualLeaf d1) $ SubstNamed d2
pure $ NamedSize d1
f _ (NamedSize d1) (ConstSize d2) = do
pure $ SizeExpr $ Var d1 typ loc
f _ (SizeExpr (Var d1 typ loc)) (SizeExpr (Literal (SignedValue (Int64Value d2)) _)) = do
modify $ M.insert (qualLeaf d1) $ SubstConst d2
pure $ NamedSize d1
pure $ SizeExpr $ Var d1 typ loc
f _ d _ = pure d

dimMapping' ::
Expand All @@ -325,7 +325,7 @@ sizesToRename (RecordSV fs) =
sizesToRename (SumSV _ svs _) =
foldMap sizesToRename svs
sizesToRename (LambdaSV param _ _ _) =
freeInPat param
(M.foldrWithKey (\k _ -> S.insert k) S.empty $ unFV $ freeInPat param)
<> S.map identName (S.filter couldBeSize $ patIdents param)
where
couldBeSize ident =
Expand Down Expand Up @@ -751,7 +751,7 @@ defuncLet ::
DefM ([VName], [Pat], Exp, StaticVal)
defuncLet dims ps@(pat : pats) body (RetType ret_dims rettype)
| patternOrderZero pat = do
let bound_by_pat = (`S.member` freeInPat pat)
let bound_by_pat = (`S.member` (M.foldrWithKey (\k _ -> S.insert k) S.empty $ unFV $ freeInPat pat))
-- Take care to not include more size parameters than necessary.
(pat_dims, rest_dims) = partition bound_by_pat dims
env = envFromPat pat <> envFromDimNames pat_dims
Expand Down Expand Up @@ -786,24 +786,24 @@ sizesForAll bound_sizes params = do
tv = identityMapper {mapOnPatType = bitraverse onDim pure}
onDim (AnySize (Just v)) = do
modify $ S.insert v
pure $ NamedSize $ qualName v
pure $ SizeExpr $ Var (qualName v) (Info <$> Scalar $ Prim $ Unsigned Int64) mempty
onDim (AnySize Nothing) = do
v <- lift $ newVName "size"
modify $ S.insert v
pure $ NamedSize $ qualName v
onDim (NamedSize d) = do
pure $ SizeExpr $ Var (qualName v) (Info <$> Scalar $ Prim $ Unsigned Int64) mempty
onDim (SizeExpr (Var d typ loc)) = do
unless (qualLeaf d `S.member` bound) $
modify $
S.insert $
qualLeaf d
pure $ NamedSize d
pure $ SizeExpr $ Var d typ loc
onDim d = pure d

unRetType :: StructRetType -> StructType
unRetType (RetType [] t) = t
unRetType (RetType ext t) = first onDim t
where
onDim (NamedSize d) | qualLeaf d `elem` ext = AnySize Nothing
onDim (SizeExpr (Var d _ _)) | qualLeaf d `elem` ext = AnySize Nothing
onDim d = d

defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal)
Expand Down Expand Up @@ -1026,7 +1026,7 @@ liftValDec fname (RetType ret_dims ret) dims pats body = addValBind dec
mkExt v
| not $ v `S.member` bound_here = Just v
mkExt _ = Nothing
rettype_st = RetType (mapMaybe mkExt (S.toList (freeInType ret)) ++ ret_dims) ret
rettype_st = RetType (mapMaybe mkExt (S.toList (M.foldrWithKey (\k _ -> S.insert k) S.empty $ unFV $ freeInType ret)) ++ ret_dims) ret

dec =
ValBind
Expand Down Expand Up @@ -1139,7 +1139,7 @@ matchPatSV (Id vn (Info t) _) sv =
else dim_env <> M.singleton vn (Binding Nothing sv)
where
dim_env =
M.fromList $ map (,i64) $ S.toList $ freeInType t
M.fromList $ map (,i64) $ S.toList $ M.foldrWithKey (\k _ -> S.insert k) S.empty $ unFV $ freeInType t
i64 = Binding Nothing $ Dynamic $ Scalar $ Prim $ Signed Int64
matchPatSV (Wildcard _ _) _ = mempty
matchPatSV (PatAscription pat _ _) sv = matchPatSV pat sv
Expand Down Expand Up @@ -1261,7 +1261,7 @@ defuncValBind valbind@(ValBind _ name retdecl (Info (RetType ret_dims rettype))
-- applications of lifted functions, we don't properly update
-- the types in the return type annotation.
combineTypeShapes rettype $ first (anyDimIfNotBound bound_sizes) $ toStruct $ typeOf body'
ret_dims' = filter (`S.member` freeInType rettype') ret_dims
ret_dims' = filter (`S.member` (M.foldrWithKey (\k _ -> S.insert k) S.empty $ unFV $ freeInType rettype')) ret_dims
(missing_dims, params'') <- sizesForAll bound_sizes params'

pure
Expand All @@ -1283,7 +1283,7 @@ defuncValBind valbind@(ValBind _ name retdecl (Info (RetType ret_dims rettype))
sv
)
where
anyDimIfNotBound bound_sizes (NamedSize v)
anyDimIfNotBound bound_sizes (SizeExpr (Var v _ _))
| qualLeaf v `S.notMember` bound_sizes = AnySize $ Just $ qualLeaf v
anyDimIfNotBound _ d = d

Expand Down
10 changes: 6 additions & 4 deletions src/Futhark/Internalise/LiftLambdas.hs
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,17 @@ liftFunction fname tparams params (RetType dims ret) funbody = do
foldMap freeInType $ M.elems $ unFV immediate_free
sizes =
freeSizes $
sizes_in_free
<> foldMap freeInPat params
<> freeInType ret
M.foldrWithKey (\k _ -> S.insert k) S.empty $
unFV $
sizes_in_free
<> foldMap freeInPat params
<> freeInType ret
in M.toList $ unFV $ immediate_free <> (sizes `freeWithout` bound)

-- Those parameters that correspond to sizes must come first.
sizes_in_types =
foldMap freeInType (ret : map snd free ++ map patternStructType params)
isSize (v, _) = v `S.member` sizes_in_types
isSize (v, _) = v `M.member` (unFV sizes_in_types)
(free_dims, free_nondims) = partition isSize free

free_params =
Expand Down
26 changes: 13 additions & 13 deletions src/Futhark/Internalise/Monomorphise.hs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ type MonoType = TypeBase MonoSize ()
monoType :: TypeBase Size als -> MonoType
monoType = (`evalState` (0, mempty)) . traverseDims onDim . toStruct
where
onDim bound _ (NamedSize d)
onDim bound _ (SizeExpr (Var d _ _))
-- A locally bound size.
| qualLeaf d `S.member` bound = pure $ MonoAnon $ qualLeaf d
onDim _ _ d = do
Expand Down Expand Up @@ -267,7 +267,7 @@ sizesForPat pat = do
onDim (AnySize _) = do
v <- lift $ newVName "size"
modify (v :)
pure $ NamedSize $ qualName v
pure $ SizeExpr $ Var (qualName v) (Info <$> Scalar $ Prim $ Unsigned Int64) mempty
onDim d = pure d

transformAppRes :: AppRes -> MonoM AppRes
Expand Down Expand Up @@ -538,9 +538,9 @@ desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (Ret
[(Observe, xext, e1)]
(AppRes (Scalar $ Arrow mempty yp Observe ytype (RetType [] t)) [])
rettype' =
let onDim (NamedSize d)
| Named p <- xp, qualLeaf d == p = NamedSize $ qualName v1
| Named p <- yp, qualLeaf d == p = NamedSize $ qualName v2
let onDim (SizeExpr (Var d typ _))
| Named p <- xp, qualLeaf d == p = SizeExpr $ Var (qualName v1) typ loc
| Named p <- yp, qualLeaf d == p = SizeExpr $ Var (qualName v2) typ loc
onDim d = d
in first onDim rettype
body =
Expand Down Expand Up @@ -610,7 +610,7 @@ desugarIndexSection idxs (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do
desugarIndexSection _ t _ = error $ "desugarIndexSection: not a function type: " ++ prettyString t

noticeDims :: TypeBase Size as -> MonoM ()
noticeDims = mapM_ notice . freeInType
noticeDims = mapM_ notice . M.foldrWithKey (\k _ -> S.insert k) S.empty . unFV . freeInType
where
notice v = void $ transformFName mempty (qualName v) i64

Expand Down Expand Up @@ -676,11 +676,11 @@ dimMapping ::
DimInst
dimMapping t1 t2 = execState (matchDims f t1 t2) mempty
where
f bound d1 (NamedSize d2)
f bound d1 (SizeExpr (Var d2 _ _))
| qualLeaf d2 `elem` bound = pure d1
f _ (NamedSize d1) d2 = do
f _ (SizeExpr (Var d1 typ loc)) d2 = do
modify $ M.insert (qualLeaf d1) d2
pure $ NamedSize d1
pure $ SizeExpr $ Var d1 typ loc
f _ d _ = pure d

inferSizeArgs :: [TypeParam] -> StructType -> StructType -> [Exp]
Expand All @@ -689,9 +689,9 @@ inferSizeArgs tparams bind_t t =
where
tparamArg dinst tp =
case M.lookup (typeParamName tp) dinst of
Just (NamedSize d) ->
Just (SizeExpr (Var d _ _)) ->
Just $ Var d (Info i64) mempty
Just (ConstSize x) ->
Just (SizeExpr (Literal (SignedValue (Int64Value x)) _)) ->
Just $ Literal (SignedValue $ Int64Value $ fromIntegral x) mempty
_ ->
Just $ Literal (SignedValue $ Int64Value 0) mempty
Expand Down Expand Up @@ -853,9 +853,9 @@ typeSubstsM loc orig_t1 orig_t2 =
d <- lift $ lift $ newVName "d"
tell [TypeParamDim d loc]
put (ts, M.insert i d sizes)
pure $ NamedSize $ qualName d
pure $ SizeExpr $ Var (qualName d) (Info <$> Scalar $ Prim $ Unsigned Int64) mempty
Just d ->
pure $ NamedSize $ qualName d
pure $ SizeExpr $ Var (qualName d) (Info <$> Scalar $ Prim $ Unsigned Int64) mempty
onDim (MonoAnon v) = pure $ AnySize $ Just v

-- Perform a given substitution on the types in a pattern.
Expand Down
5 changes: 3 additions & 2 deletions src/Futhark/Internalise/TypesValues.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ internaliseDim ::
internaliseDim exts d =
case d of
E.AnySize _ -> Ext <$> newId
E.ConstSize n -> pure $ Free $ intConst I.Int64 $ toInteger n
E.NamedSize name -> pure $ namedDim name
E.SizeExpr (E.Literal (E.SignedValue (E.Int64Value n)) _) -> pure $ Free $ intConst I.Int64 $ toInteger n
E.SizeExpr (E.Var name _ _) -> pure $ namedDim name
E.SizeExpr _ -> error "Arbitrary Expression not supported yet"
where
namedDim (E.QualName _ name)
| Just x <- name `M.lookup` exts = I.Ext x
Expand Down
25 changes: 15 additions & 10 deletions src/Language/Futhark/FreeVars.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ freeWithout (FV x) y = FV $ M.filterWithKey keep x
ident :: IdentBase Info VName -> FV
ident v = FV $ M.singleton (identName v) (toStruct $ unInfo (identType v))

{-
size :: VName -> FV
size v = FV $ M.singleton v $ Scalar $ Prim $ Signed Int64
-- | A 'FV' with these names, considered to be sizes.
sizes :: S.Set VName -> FV
sizes = foldMap size
-}

-- | Compute the set of free variables of an expression.
freeInExp :: ExpBase Info VName -> FV
Expand All @@ -56,20 +58,20 @@ freeInExp expr = case expr of
freeInExpField (RecordFieldExplicit _ e _) = freeInExp e
freeInExpField (RecordFieldImplicit vn t _) = ident $ Ident vn t mempty
ArrayLit es t _ ->
foldMap freeInExp es <> sizes (freeInType $ unInfo t)
foldMap freeInExp es <> (freeInType $ unInfo t)
AppExp (Range e me incl _) _ ->
freeInExp e <> foldMap freeInExp me <> foldMap freeInExp incl
Var qn (Info t) _ -> FV $ M.singleton (qualLeaf qn) $ toStruct t
Ascript e _ _ -> freeInExp e
AppExp (Coerce e _ _) (Info ar) ->
freeInExp e <> sizes (freeInType (appResType ar))
freeInExp e <> (freeInType (appResType ar))
AppExp (LetPat let_sizes pat e1 e2 _) _ ->
freeInExp e1
<> ( (sizes (freeInPat pat) <> freeInExp e2)
<> ( ((freeInPat pat) <> freeInExp e2)
`freeWithout` (patNames pat <> S.fromList (map sizeName let_sizes))
)
AppExp (LetFun vn (tparams, pats, _, _, e1) e2 _) _ ->
( (freeInExp e1 <> sizes (foldMap freeInPat pats))
( (freeInExp e1 <> (foldMap freeInPat pats))
`freeWithout` ( foldMap patNames pats
<> S.fromList (map typeParamName tparams)
)
Expand All @@ -80,7 +82,7 @@ freeInExp expr = case expr of
Negate e _ -> freeInExp e
Not e _ -> freeInExp e
Lambda pats e0 _ (Info (_, RetType dims t)) _ ->
(sizes (foldMap freeInPat pats) <> freeInExp e0 <> sizes (freeInType t))
((foldMap freeInPat pats) <> freeInExp e0 <> (freeInType t))
`freeWithout` (foldMap patNames pats <> S.fromList dims)
OpSection {} -> mempty
OpSectionLeft _ _ e _ _ _ -> freeInExp e
Expand Down Expand Up @@ -116,7 +118,7 @@ freeInExp expr = case expr of
AppExp (Match e cs _) _ -> freeInExp e <> foldMap caseFV cs
where
caseFV (CasePat p eCase _) =
(sizes (freeInPat p) <> freeInExp eCase)
((freeInPat p) <> freeInExp eCase)
`freeWithout` patNames p

freeInDimIndex :: DimIndexBase Info VName -> FV
Expand All @@ -125,7 +127,7 @@ freeInDimIndex (DimSlice me1 me2 me3) =
foldMap (foldMap freeInExp) [me1, me2, me3]

-- | Free variables in pattern (including types of the bound identifiers).
freeInPat :: PatBase Info VName -> S.Set VName
freeInPat :: PatBase Info VName -> FV
freeInPat (TuplePat ps _) = foldMap freeInPat ps
freeInPat (RecordPat fs _) = foldMap (freeInPat . snd) fs
freeInPat (PatParens p _) = freeInPat p
Expand All @@ -137,7 +139,7 @@ freeInPat (PatConstr _ _ ps _) = foldMap freeInPat ps
freeInPat (PatAttr _ p _) = freeInPat p

-- | Free variables in the type (meaning those that are used in size expression).
freeInType :: TypeBase Size as -> S.Set VName
freeInType :: TypeBase Size as -> FV
freeInType t =
case t of
Array _ _ s a ->
Expand All @@ -149,7 +151,10 @@ freeInType t =
Scalar (Sum cs) ->
foldMap (foldMap freeInType) cs
Scalar (Arrow _ v _ t1 (RetType dims t2)) ->
S.filter (notV v) $ S.filter (`notElem` dims) $ freeInType t1 <> freeInType t2
FV $
M.filterWithKey (\key _ -> notV v key) $
M.filterWithKey (\key _ -> key `notElem` dims) $
(unFV $ freeInType t1) <> (unFV $ freeInType t2)
Scalar (TypeVar _ _ _ targs) ->
foldMap typeArgDims targs
where
Expand All @@ -159,5 +164,5 @@ freeInType t =
notV Unnamed = const True
notV (Named v) = (/= v)

onSize (NamedSize qn) = S.singleton $ qualLeaf qn
onSize (SizeExpr (Var qn (Info ty) _)) = FV $ M.singleton (qualLeaf qn) $ toStruct ty
onSize _ = mempty
Loading

0 comments on commit 9476827

Please sign in to comment.