Skip to content

Commit

Permalink
Make source language Apply AST node multi-argument.
Browse files Browse the repository at this point in the history
This is a deviation from the concrete syntax, but humans tend to think
of function calls having multiple arguments.  Also, the AST had to
keep a lot of useless metadata around to express the results of the
intermediate applications.

And again, it is related to making #1872 more feasible.
  • Loading branch information
athas committed Feb 13, 2023
1 parent 9e4ab92 commit cfdf138
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 258 deletions.
295 changes: 137 additions & 158 deletions src/Futhark/Internalise/Defunctionalise.hs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ defuncExp (AppExp (If e1 e2 e3 loc) res) = do
(e2', sv) <- defuncExp e2
(e3', _) <- defuncExp e3
pure (AppExp (If e1' e2' e3' loc) res, sv)
defuncExp e@(AppExp Apply {} _) = defuncApply 0 e
defuncExp (AppExp (Apply f args loc) (Info appres)) =
defuncApply f (fmap (first unInfo) args) appres loc
defuncExp (Negate e0 loc) = do
(e0', sv) <- defuncExp e0
pure (Negate e0' loc, sv)
Expand Down Expand Up @@ -706,9 +707,8 @@ etaExpand e_t e = do
let e' =
foldl'
( \e1 (e2, t2, argtypes) ->
AppExp
(Apply e1 e2 (Info (diet t2, Nothing)) mempty)
(Info (AppRes (foldFunType argtypes ret) []))
mkApply e1 [(diet t2, Nothing, e2)] $
AppRes (foldFunType argtypes ret) []
)
e
$ zip3 vars (map (snd . snd) ps) (drop 1 $ tails $ map snd ps)
Expand Down Expand Up @@ -807,164 +807,21 @@ unRetType (RetType ext t) = first onDim t
onDim (NamedSize d) | qualLeaf d `elem` ext = AnySize Nothing
onDim d = d

-- | Defunctionalize an application expression at a given depth of application.
-- Calls to dynamic (first-order) functions are preserved at much as possible,
-- but a new lifted function is created if a dynamic function is only partially
-- applied.
defuncApply :: Int -> Exp -> DefM (Exp, StaticVal)
defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do
let (argtypes, _) = unfoldFunType ret
(e1', sv1) <- defuncApply (depth + 1) e1
case sv1 of
LambdaSV pat e0_t e0 closure_env -> do
(e2', sv2) <- defuncExp e2
let env' = matchPatSV pat sv2
dims = mempty
(e0', sv) <-
localNewEnv (env' <> closure_env) $
defuncExp e0

let closure_pat = buildEnvPat dims closure_env
pat' = updatePat pat sv2

globals <- asks fst

-- Lift lambda to top-level function definition. We put in
-- a lot of effort to try to infer the uniqueness attributes
-- of the lifted function, but this is ultimately all a sham
-- and a hack. There is some piece we're missing.
let params = [closure_pat, pat']
params_for_rettype = params ++ svParams sv1 ++ svParams sv2
svParams (LambdaSV sv_pat _ _ _) = [sv_pat]
svParams _ = []
lifted_rettype = buildRetType closure_env params_for_rettype (unRetType e0_t) $ typeOf e0'

already_bound =
globals
<> S.fromList dims
<> S.map identName (foldMap patIdents params)

more_dims =
S.toList $
S.filter (`S.notMember` already_bound) $
foldMap patternArraySizes params

-- Embed some information about the original function
-- into the name of the lifted function, to make the
-- result slightly more human-readable.
liftedName i (Var f _ _) =
"defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f)
liftedName i (AppExp (Apply f _ _ _) _) =
liftedName (i + 1) f
liftedName _ _ = "defunc"

-- Ensure that no parameter sizes are AnySize. The internaliser
-- expects this. This is easy, because they are all
-- first-order.
let bound_sizes = S.fromList (dims <> more_dims) <> globals
(missing_dims, params') <- sizesForAll bound_sizes params

fname <- newNameFromString $ liftedName (0 :: Int) e1
liftValDec
fname
(RetType [] $ toStruct lifted_rettype)
(dims ++ more_dims ++ missing_dims)
params'
e0'

let t1 = toStruct $ typeOf e1'
t2 = toStruct $ typeOf e2'
d1 = Observe
d2 = Observe
fname' = qualName fname
fname'' =
Var
fname'
( Info
( Scalar . Arrow mempty Unnamed d1 t1 . RetType [] $
Scalar . Arrow mempty Unnamed d2 t2 $
RetType [] lifted_rettype
)
)
loc

callret = AppRes (combineTypeShapes ret lifted_rettype) ext

innercallret =
AppRes
(Scalar $ Arrow mempty Unnamed d2 t2 $ RetType [] lifted_rettype)
[]

pure
( AppExp
( Apply
( AppExp
(Apply fname'' e1' (Info (Observe, Nothing)) loc)
(Info innercallret)
)
e2'
d
loc
)
(Info callret),
sv
)

-- If e1 is a dynamic function, we just leave the application in place,
-- but we update the types since it may be partially applied or return
-- a higher-order term.
DynamicFun _ sv -> do
(e2', _) <- defuncExp e2
let (argtypes', rettype) = dynamicFunType sv argtypes
restype = foldFunType argtypes' (RetType [] rettype) `setAliases` aliases ret
callret = AppRes (combineTypeShapes ret restype) ext
apply_e = AppExp (Apply e1' e2' d loc) (Info callret)
pure (apply_e, sv)
-- Propagate the 'IntrinsicsSV' until we reach the outermost application,
-- where we construct a dynamic static value with the appropriate type.
IntrinsicSV -> do
e2' <- defuncSoacExp e2
let e' = AppExp (Apply e1' e2' d loc) t
intrinsicOrHole argtypes e' sv1
HoleSV {} -> do
(e2', _) <- defuncExp e2
let e' = AppExp (Apply e1' e2' d loc) t
intrinsicOrHole argtypes e' sv1
_ ->
error $
"Application of an expression\n"
++ prettyString e1
++ "\nthat is neither a static lambda "
++ "nor a dynamic function, but has static value:\n"
++ show sv1
where
intrinsicOrHole argtypes e' sv
| depth == 0 =
-- If the intrinsic is fully applied, then we are done.
-- Otherwise we need to eta-expand it and recursively
-- defunctionalise. XXX: might it be better to simply
-- eta-expand immediately any time we encounter a
-- non-fully-applied intrinsic?
if null argtypes
then pure (e', Dynamic $ typeOf e)
else do
(pats, body, tp) <- etaExpand (typeOf e') e'
defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty
| otherwise = pure (e', sv)
defuncApply depth e@(Var qn (Info t) loc) = do
defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal)
defuncApplyFunction e@(Var qn (Info t) loc) num_args = do
let (argtypes, _) = unfoldFunType t
sv <- lookupVar (toStruct t) (qualLeaf qn)

case sv of
DynamicFun _ _
| fullyApplied sv depth -> do
| fullyApplied sv num_args -> do
-- We still need to update the types in case the dynamic
-- function returns a higher-order term.
let (argtypes', rettype) = dynamicFunType sv argtypes
pure (Var qn (Info (foldFunType argtypes' $ RetType [] rettype)) loc, sv)
| otherwise -> do
fname <- newVName $ "dyn_" <> baseString (qualLeaf qn)
let (pats, e0, sv') = liftDynFun (prettyString qn) sv depth
let (pats, e0, sv') = liftDynFun (prettyString qn) sv num_args
(argtypes', rettype) = dynamicFunType sv' argtypes
dims' = mempty

Expand All @@ -985,8 +842,131 @@ defuncApply depth e@(Var qn (Info t) loc) = do
)
IntrinsicSV -> pure (e, IntrinsicSV)
_ -> pure (Var qn (Info (typeFromSV sv)) loc, sv)
defuncApply depth (Parens e _) = defuncApply depth e
defuncApply _ expr = defuncExp expr
defuncApplyFunction e _ = defuncExp e

-- Embed some information about the original function
-- into the name of the lifted function, to make the
-- result slightly more human-readable.
liftedName :: Int -> Exp -> String
liftedName i (Var f _ _) =
"defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f)
liftedName i (AppExp (Apply f _ _) _) =
liftedName (i + 1) f
liftedName _ _ = "defunc"

defuncApplyArg ::
String ->
(Exp, StaticVal) ->
(((Diet, Maybe VName), Exp), [(Diet, StructType)]) ->
DefM (Exp, StaticVal)
defuncApplyArg fname_s (f', f_sv@(LambdaSV pat lam_e_t lam_e closure_env)) (((d, argext), arg), _) = do
(arg', arg_sv) <- defuncExp arg
let env' = matchPatSV pat arg_sv
dims = mempty
(lam_e', sv) <-
localNewEnv (env' <> closure_env) $
defuncExp lam_e

let closure_pat = buildEnvPat dims closure_env
pat' = updatePat pat arg_sv

globals <- asks fst

-- Lift lambda to top-level function definition. We put in
-- a lot of effort to try to infer the uniqueness attributes
-- of the lifted function, but this is ultimately all a sham
-- and a hack. There is some piece we're missing.
let params = [closure_pat, pat']
params_for_rettype = params ++ svParams f_sv ++ svParams arg_sv
svParams (LambdaSV sv_pat _ _ _) = [sv_pat]
svParams _ = []
lifted_rettype = buildRetType closure_env params_for_rettype (unRetType lam_e_t) $ typeOf lam_e'

already_bound =
globals
<> S.fromList dims
<> S.map identName (foldMap patIdents params)

more_dims =
S.toList $
S.filter (`S.notMember` already_bound) $
foldMap patternArraySizes params

-- Ensure that no parameter sizes are AnySize. The internaliser
-- expects this. This is easy, because they are all
-- first-order.
let bound_sizes = S.fromList (dims <> more_dims) <> globals
(missing_dims, params') <- sizesForAll bound_sizes params

fname <- newNameFromString fname_s
liftValDec
fname
(RetType [] $ toStruct lifted_rettype)
(dims ++ more_dims ++ missing_dims)
params'
lam_e'

let f_t = toStruct $ typeOf f'
arg_t = toStruct $ typeOf arg'
d1 = Observe
fname_t = foldFunType [(d1, f_t), (d, arg_t)] $ RetType [] lifted_rettype
fname' = Var (qualName fname) (Info fname_t) (srclocOf arg)
callret = AppRes lifted_rettype []

pure
( mkApply fname' [(Observe, Nothing, f'), (Observe, argext, arg')] callret,
sv
)
-- If 'f' is a dynamic function, we just leave the application in
-- place, but we update the types since it may be partially
-- applied or return a higher-order value.
defuncApplyArg _ (f', DynamicFun _ sv) (((d, argext), arg), argtypes) = do
(arg', _) <- defuncExp arg
let (argtypes', rettype) = dynamicFunType sv argtypes
restype = foldFunType argtypes' (RetType [] rettype)
callret = AppRes restype []
apply_e = mkApply f' [(d, argext, arg')] callret
pure (apply_e, sv)
--
defuncApplyArg _ (_, sv) _ =
error $ "defuncApplyArg: cannot apply StaticVal\n" <> show sv

updateReturn :: AppRes -> Exp -> Exp
updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) =
AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2)
updateReturn _ e = e

defuncApply :: Exp -> NE.NonEmpty ((Diet, Maybe VName), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal)
defuncApply f args appres loc = do
(f', f_sv) <- defuncApplyFunction f (length args)
case f_sv of
IntrinsicSV -> do
args' <- fmap (first Info) <$> traverse (traverse defuncSoacExp) args
let e' = AppExp (Apply f' args' loc) (Info appres)
intrinsicOrHole e'
HoleSV {} -> do
args' <- fmap (first Info) <$> traverse (traverse $ fmap fst . defuncExp) args
let e' = AppExp (Apply f' args' loc) (Info appres)
intrinsicOrHole e'
_ -> do
let fname = liftedName 0 f
(argtypes, _) = unfoldFunType $ typeOf f
fmap (first $ updateReturn appres) $
foldM (defuncApplyArg fname) (f', f_sv) $
NE.zip args $
NE.tails argtypes
where
intrinsicOrHole e' = do
-- If the intrinsic is fully applied, then we are done.
-- Otherwise we need to eta-expand it and recursively
-- defunctionalise. XXX: might it be better to simply eta-expand
-- immediately any time we encounter a non-fully-applied
-- intrinsic?
if null $ fst $ unfoldFunType $ appResType appres
then pure (e', Dynamic $ appResType appres)
else do
(pats, body, tp) <- etaExpand (typeOf e') e'
defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty

-- | Check if a 'StaticVal' and a given application depth corresponds
-- to a fully applied dynamic function.
Expand Down Expand Up @@ -1181,11 +1161,10 @@ matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t
matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType t
matchPatSV pat sv =
error $
"Tried to match pattern "
"Tried to match pattern\n"
++ prettyString pat
++ " with static value "
++ "\n with static value\n"
++ show sv
++ "."

orderZeroSV :: StaticVal -> Bool
orderZeroSV Dynamic {} = True
Expand Down Expand Up @@ -1235,9 +1214,9 @@ updatePat pat (Dynamic t) = updatePat pat (svFromType t)
updatePat pat (HoleSV t _) = updatePat pat (svFromType t)
updatePat pat sv =
error $
"Tried to update pattern "
"Tried to update pattern\n"
++ prettyString pat
++ "to reflect the static value "
++ "\nto reflect the static value\n"
++ show sv

-- | Convert a record (or tuple) type to a record static value. This is used for
Expand Down
11 changes: 5 additions & 6 deletions src/Futhark/Internalise/Exps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1443,14 +1443,13 @@ data Function
deriving (Show)

findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)])
findFuncall (E.Apply f arg (Info (_, argext)) _)
| E.AppExp f_e _ <- f =
let (f_e', args) = findFuncall f_e
in (f_e', args ++ [(arg, argext)])
findFuncall (E.Apply f args _)
| E.Var fname _ _ <- f =
(FunctionName fname, [(arg, argext)])
(FunctionName fname, map onArg $ NE.toList args)
| E.Hole (Info t) loc <- f =
(FunctionHole t loc, [(arg, argext)])
(FunctionHole t loc, map onArg $ NE.toList args)
where
onArg (Info (_, argext), e) = (e, argext)
findFuncall e =
error $ "Invalid function expression in application:\n" ++ prettyString e

Expand Down
10 changes: 6 additions & 4 deletions src/Futhark/Internalise/LiftLambdas.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ replacing v e = local $ \env ->

existentials :: Exp -> S.Set VName
existentials e =
let here = case e of
AppExp (Apply _ _ (Info (_, pdim)) _) (Info res) ->
S.fromList (maybeToList pdim ++ appResExt res)
let onArg (Info (_, pdim), _) =
maybeToList pdim
here = case e of
AppExp (Apply _ args _) (Info res) ->
S.fromList (foldMap onArg args <> appResExt res)
AppExp _ (Info res) ->
S.fromList (appResExt res)
_ ->
Expand Down Expand Up @@ -129,7 +131,7 @@ liftFunction fname tparams params (RetType dims ret) funbody = do
apply f [] = f
apply f (p : rem_ps) =
let inner_ret = AppRes (fromStruct (augType rem_ps)) mempty
inner = AppExp (Apply f (freeVar p) (Info (Observe, Nothing)) mempty) (Info inner_ret)
inner = mkApply f [(Observe, Nothing, freeVar p)] inner_ret
in apply inner rem_ps

transformExp :: Exp -> LiftM Exp
Expand Down
Loading

0 comments on commit cfdf138

Please sign in to comment.