Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make source language Apply AST node multi-argument. #1875

Merged
merged 1 commit into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 137 additions & 158 deletions src/Futhark/Internalise/Defunctionalise.hs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ defuncExp (AppExp (If e1 e2 e3 loc) res) = do
(e2', sv) <- defuncExp e2
(e3', _) <- defuncExp e3
pure (AppExp (If e1' e2' e3' loc) res, sv)
defuncExp e@(AppExp Apply {} _) = defuncApply 0 e
defuncExp (AppExp (Apply f args loc) (Info appres)) =
defuncApply f (fmap (first unInfo) args) appres loc
defuncExp (Negate e0 loc) = do
(e0', sv) <- defuncExp e0
pure (Negate e0' loc, sv)
Expand Down Expand Up @@ -706,9 +707,8 @@ etaExpand e_t e = do
let e' =
foldl'
( \e1 (e2, t2, argtypes) ->
AppExp
(Apply e1 e2 (Info (diet t2, Nothing)) mempty)
(Info (AppRes (foldFunType argtypes ret) []))
mkApply e1 [(diet t2, Nothing, e2)] $
AppRes (foldFunType argtypes ret) []
)
e
$ zip3 vars (map (snd . snd) ps) (drop 1 $ tails $ map snd ps)
Expand Down Expand Up @@ -807,164 +807,21 @@ unRetType (RetType ext t) = first onDim t
onDim (NamedSize d) | qualLeaf d `elem` ext = AnySize Nothing
onDim d = d

-- | Defunctionalize an application expression at a given depth of application.
-- Calls to dynamic (first-order) functions are preserved at much as possible,
-- but a new lifted function is created if a dynamic function is only partially
-- applied.
defuncApply :: Int -> Exp -> DefM (Exp, StaticVal)
defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do
let (argtypes, _) = unfoldFunType ret
(e1', sv1) <- defuncApply (depth + 1) e1
case sv1 of
LambdaSV pat e0_t e0 closure_env -> do
(e2', sv2) <- defuncExp e2
let env' = matchPatSV pat sv2
dims = mempty
(e0', sv) <-
localNewEnv (env' <> closure_env) $
defuncExp e0

let closure_pat = buildEnvPat dims closure_env
pat' = updatePat pat sv2

globals <- asks fst

-- Lift lambda to top-level function definition. We put in
-- a lot of effort to try to infer the uniqueness attributes
-- of the lifted function, but this is ultimately all a sham
-- and a hack. There is some piece we're missing.
let params = [closure_pat, pat']
params_for_rettype = params ++ svParams sv1 ++ svParams sv2
svParams (LambdaSV sv_pat _ _ _) = [sv_pat]
svParams _ = []
lifted_rettype = buildRetType closure_env params_for_rettype (unRetType e0_t) $ typeOf e0'

already_bound =
globals
<> S.fromList dims
<> S.map identName (foldMap patIdents params)

more_dims =
S.toList $
S.filter (`S.notMember` already_bound) $
foldMap patternArraySizes params

-- Embed some information about the original function
-- into the name of the lifted function, to make the
-- result slightly more human-readable.
liftedName i (Var f _ _) =
"defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f)
liftedName i (AppExp (Apply f _ _ _) _) =
liftedName (i + 1) f
liftedName _ _ = "defunc"

-- Ensure that no parameter sizes are AnySize. The internaliser
-- expects this. This is easy, because they are all
-- first-order.
let bound_sizes = S.fromList (dims <> more_dims) <> globals
(missing_dims, params') <- sizesForAll bound_sizes params

fname <- newNameFromString $ liftedName (0 :: Int) e1
liftValDec
fname
(RetType [] $ toStruct lifted_rettype)
(dims ++ more_dims ++ missing_dims)
params'
e0'

let t1 = toStruct $ typeOf e1'
t2 = toStruct $ typeOf e2'
d1 = Observe
d2 = Observe
fname' = qualName fname
fname'' =
Var
fname'
( Info
( Scalar . Arrow mempty Unnamed d1 t1 . RetType [] $
Scalar . Arrow mempty Unnamed d2 t2 $
RetType [] lifted_rettype
)
)
loc

callret = AppRes (combineTypeShapes ret lifted_rettype) ext

innercallret =
AppRes
(Scalar $ Arrow mempty Unnamed d2 t2 $ RetType [] lifted_rettype)
[]

pure
( AppExp
( Apply
( AppExp
(Apply fname'' e1' (Info (Observe, Nothing)) loc)
(Info innercallret)
)
e2'
d
loc
)
(Info callret),
sv
)

-- If e1 is a dynamic function, we just leave the application in place,
-- but we update the types since it may be partially applied or return
-- a higher-order term.
DynamicFun _ sv -> do
(e2', _) <- defuncExp e2
let (argtypes', rettype) = dynamicFunType sv argtypes
restype = foldFunType argtypes' (RetType [] rettype) `setAliases` aliases ret
callret = AppRes (combineTypeShapes ret restype) ext
apply_e = AppExp (Apply e1' e2' d loc) (Info callret)
pure (apply_e, sv)
-- Propagate the 'IntrinsicsSV' until we reach the outermost application,
-- where we construct a dynamic static value with the appropriate type.
IntrinsicSV -> do
e2' <- defuncSoacExp e2
let e' = AppExp (Apply e1' e2' d loc) t
intrinsicOrHole argtypes e' sv1
HoleSV {} -> do
(e2', _) <- defuncExp e2
let e' = AppExp (Apply e1' e2' d loc) t
intrinsicOrHole argtypes e' sv1
_ ->
error $
"Application of an expression\n"
++ prettyString e1
++ "\nthat is neither a static lambda "
++ "nor a dynamic function, but has static value:\n"
++ show sv1
where
intrinsicOrHole argtypes e' sv
| depth == 0 =
-- If the intrinsic is fully applied, then we are done.
-- Otherwise we need to eta-expand it and recursively
-- defunctionalise. XXX: might it be better to simply
-- eta-expand immediately any time we encounter a
-- non-fully-applied intrinsic?
if null argtypes
then pure (e', Dynamic $ typeOf e)
else do
(pats, body, tp) <- etaExpand (typeOf e') e'
defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty
| otherwise = pure (e', sv)
defuncApply depth e@(Var qn (Info t) loc) = do
defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal)
defuncApplyFunction e@(Var qn (Info t) loc) num_args = do
let (argtypes, _) = unfoldFunType t
sv <- lookupVar (toStruct t) (qualLeaf qn)

case sv of
DynamicFun _ _
| fullyApplied sv depth -> do
| fullyApplied sv num_args -> do
-- We still need to update the types in case the dynamic
-- function returns a higher-order term.
let (argtypes', rettype) = dynamicFunType sv argtypes
pure (Var qn (Info (foldFunType argtypes' $ RetType [] rettype)) loc, sv)
| otherwise -> do
fname <- newVName $ "dyn_" <> baseString (qualLeaf qn)
let (pats, e0, sv') = liftDynFun (prettyString qn) sv depth
let (pats, e0, sv') = liftDynFun (prettyString qn) sv num_args
(argtypes', rettype) = dynamicFunType sv' argtypes
dims' = mempty

Expand All @@ -985,8 +842,131 @@ defuncApply depth e@(Var qn (Info t) loc) = do
)
IntrinsicSV -> pure (e, IntrinsicSV)
_ -> pure (Var qn (Info (typeFromSV sv)) loc, sv)
defuncApply depth (Parens e _) = defuncApply depth e
defuncApply _ expr = defuncExp expr
defuncApplyFunction e _ = defuncExp e

-- Embed some information about the original function
-- into the name of the lifted function, to make the
-- result slightly more human-readable.
liftedName :: Int -> Exp -> String
liftedName i (Var f _ _) =
"defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f)
liftedName i (AppExp (Apply f _ _) _) =
liftedName (i + 1) f
liftedName _ _ = "defunc"

defuncApplyArg ::
String ->
(Exp, StaticVal) ->
(((Diet, Maybe VName), Exp), [(Diet, StructType)]) ->
DefM (Exp, StaticVal)
defuncApplyArg fname_s (f', f_sv@(LambdaSV pat lam_e_t lam_e closure_env)) (((d, argext), arg), _) = do
(arg', arg_sv) <- defuncExp arg
let env' = matchPatSV pat arg_sv
dims = mempty
(lam_e', sv) <-
localNewEnv (env' <> closure_env) $
defuncExp lam_e

let closure_pat = buildEnvPat dims closure_env
pat' = updatePat pat arg_sv

globals <- asks fst

-- Lift lambda to top-level function definition. We put in
-- a lot of effort to try to infer the uniqueness attributes
-- of the lifted function, but this is ultimately all a sham
-- and a hack. There is some piece we're missing.
let params = [closure_pat, pat']
params_for_rettype = params ++ svParams f_sv ++ svParams arg_sv
svParams (LambdaSV sv_pat _ _ _) = [sv_pat]
svParams _ = []
lifted_rettype = buildRetType closure_env params_for_rettype (unRetType lam_e_t) $ typeOf lam_e'

already_bound =
globals
<> S.fromList dims
<> S.map identName (foldMap patIdents params)

more_dims =
S.toList $
S.filter (`S.notMember` already_bound) $
foldMap patternArraySizes params

-- Ensure that no parameter sizes are AnySize. The internaliser
-- expects this. This is easy, because they are all
-- first-order.
let bound_sizes = S.fromList (dims <> more_dims) <> globals
(missing_dims, params') <- sizesForAll bound_sizes params

fname <- newNameFromString fname_s
liftValDec
fname
(RetType [] $ toStruct lifted_rettype)
(dims ++ more_dims ++ missing_dims)
params'
lam_e'

let f_t = toStruct $ typeOf f'
arg_t = toStruct $ typeOf arg'
d1 = Observe
fname_t = foldFunType [(d1, f_t), (d, arg_t)] $ RetType [] lifted_rettype
fname' = Var (qualName fname) (Info fname_t) (srclocOf arg)
callret = AppRes lifted_rettype []

pure
( mkApply fname' [(Observe, Nothing, f'), (Observe, argext, arg')] callret,
sv
)
-- If 'f' is a dynamic function, we just leave the application in
-- place, but we update the types since it may be partially
-- applied or return a higher-order value.
defuncApplyArg _ (f', DynamicFun _ sv) (((d, argext), arg), argtypes) = do
(arg', _) <- defuncExp arg
let (argtypes', rettype) = dynamicFunType sv argtypes
restype = foldFunType argtypes' (RetType [] rettype)
callret = AppRes restype []
apply_e = mkApply f' [(d, argext, arg')] callret
pure (apply_e, sv)
--
defuncApplyArg _ (_, sv) _ =
error $ "defuncApplyArg: cannot apply StaticVal\n" <> show sv

updateReturn :: AppRes -> Exp -> Exp
updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) =
AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2)
updateReturn _ e = e

defuncApply :: Exp -> NE.NonEmpty ((Diet, Maybe VName), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal)
defuncApply f args appres loc = do
(f', f_sv) <- defuncApplyFunction f (length args)
case f_sv of
IntrinsicSV -> do
args' <- fmap (first Info) <$> traverse (traverse defuncSoacExp) args
let e' = AppExp (Apply f' args' loc) (Info appres)
intrinsicOrHole e'
HoleSV {} -> do
args' <- fmap (first Info) <$> traverse (traverse $ fmap fst . defuncExp) args
let e' = AppExp (Apply f' args' loc) (Info appres)
intrinsicOrHole e'
_ -> do
let fname = liftedName 0 f
(argtypes, _) = unfoldFunType $ typeOf f
fmap (first $ updateReturn appres) $
foldM (defuncApplyArg fname) (f', f_sv) $
NE.zip args $
NE.tails argtypes
where
intrinsicOrHole e' = do
-- If the intrinsic is fully applied, then we are done.
-- Otherwise we need to eta-expand it and recursively
-- defunctionalise. XXX: might it be better to simply eta-expand
-- immediately any time we encounter a non-fully-applied
-- intrinsic?
if null $ fst $ unfoldFunType $ appResType appres
then pure (e', Dynamic $ appResType appres)
else do
(pats, body, tp) <- etaExpand (typeOf e') e'
defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty

-- | Check if a 'StaticVal' and a given application depth corresponds
-- to a fully applied dynamic function.
Expand Down Expand Up @@ -1181,11 +1161,10 @@ matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t
matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType t
matchPatSV pat sv =
error $
"Tried to match pattern "
"Tried to match pattern\n"
++ prettyString pat
++ " with static value "
++ "\n with static value\n"
++ show sv
++ "."

orderZeroSV :: StaticVal -> Bool
orderZeroSV Dynamic {} = True
Expand Down Expand Up @@ -1235,9 +1214,9 @@ updatePat pat (Dynamic t) = updatePat pat (svFromType t)
updatePat pat (HoleSV t _) = updatePat pat (svFromType t)
updatePat pat sv =
error $
"Tried to update pattern "
"Tried to update pattern\n"
++ prettyString pat
++ "to reflect the static value "
++ "\nto reflect the static value\n"
++ show sv

-- | Convert a record (or tuple) type to a record static value. This is used for
Expand Down
11 changes: 5 additions & 6 deletions src/Futhark/Internalise/Exps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1443,14 +1443,13 @@ data Function
deriving (Show)

findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)])
findFuncall (E.Apply f arg (Info (_, argext)) _)
| E.AppExp f_e _ <- f =
let (f_e', args) = findFuncall f_e
in (f_e', args ++ [(arg, argext)])
findFuncall (E.Apply f args _)
| E.Var fname _ _ <- f =
(FunctionName fname, [(arg, argext)])
(FunctionName fname, map onArg $ NE.toList args)
| E.Hole (Info t) loc <- f =
(FunctionHole t loc, [(arg, argext)])
(FunctionHole t loc, map onArg $ NE.toList args)
where
onArg (Info (_, argext), e) = (e, argext)
findFuncall e =
error $ "Invalid function expression in application:\n" ++ prettyString e

Expand Down
10 changes: 6 additions & 4 deletions src/Futhark/Internalise/LiftLambdas.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ replacing v e = local $ \env ->

existentials :: Exp -> S.Set VName
existentials e =
let here = case e of
AppExp (Apply _ _ (Info (_, pdim)) _) (Info res) ->
S.fromList (maybeToList pdim ++ appResExt res)
let onArg (Info (_, pdim), _) =
maybeToList pdim
here = case e of
AppExp (Apply _ args _) (Info res) ->
S.fromList (foldMap onArg args <> appResExt res)
AppExp _ (Info res) ->
S.fromList (appResExt res)
_ ->
Expand Down Expand Up @@ -129,7 +131,7 @@ liftFunction fname tparams params (RetType dims ret) funbody = do
apply f [] = f
apply f (p : rem_ps) =
let inner_ret = AppRes (fromStruct (augType rem_ps)) mempty
inner = AppExp (Apply f (freeVar p) (Info (Observe, Nothing)) mempty) (Info inner_ret)
inner = mkApply f [(Observe, Nothing, freeVar p)] inner_ret
in apply inner rem_ps

transformExp :: Exp -> LiftM Exp
Expand Down
Loading