From 4d173e27dc04a336592e0a43f5139c9cfa6f8386 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 13 Feb 2023 11:06:34 +0100 Subject: [PATCH] Make source language Apply AST node multi-argument. (#1875) This is a deviation from the concrete syntax, but humans tend to think of function calls having multiple arguments. Also, the AST had to keep a lot of useless metadata around to express the results of the intermediate applications. And again, it is related to making #1872 more feasible. --- src/Futhark/Internalise/Defunctionalise.hs | 295 ++++++++++----------- src/Futhark/Internalise/Exps.hs | 11 +- src/Futhark/Internalise/LiftLambdas.hs | 10 +- src/Futhark/Internalise/Monomorphise.hs | 53 ++-- src/Futhark/Util.hs | 24 +- src/Language/Futhark/FreeVars.hs | 2 +- src/Language/Futhark/Interpreter.hs | 13 +- src/Language/Futhark/Parser/Monad.hs | 3 +- src/Language/Futhark/Pretty.hs | 6 +- src/Language/Futhark/Syntax.hs | 42 ++- src/Language/Futhark/Traversals.hs | 8 +- src/Language/Futhark/TypeChecker/Terms.hs | 56 ++-- 12 files changed, 265 insertions(+), 258 deletions(-) diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 36c5d20e05..01a3157c41 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -533,7 +533,8 @@ defuncExp (AppExp (If e1 e2 e3 loc) res) = do (e2', sv) <- defuncExp e2 (e3', _) <- defuncExp e3 pure (AppExp (If e1' e2' e3' loc) res, sv) -defuncExp e@(AppExp Apply {} _) = defuncApply 0 e +defuncExp (AppExp (Apply f args loc) (Info appres)) = + defuncApply f (fmap (first unInfo) args) appres loc defuncExp (Negate e0 loc) = do (e0', sv) <- defuncExp e0 pure (Negate e0' loc, sv) @@ -706,9 +707,8 @@ etaExpand e_t e = do let e' = foldl' ( \e1 (e2, t2, argtypes) -> - AppExp - (Apply e1 e2 (Info (diet t2, Nothing)) mempty) - (Info (AppRes (foldFunType argtypes ret) [])) + mkApply e1 [(diet t2, Nothing, e2)] $ + AppRes (foldFunType argtypes ret) [] ) e $ zip3 vars (map (snd . snd) ps) (drop 1 $ tails $ map snd ps) @@ -807,164 +807,21 @@ unRetType (RetType ext t) = first onDim t onDim (NamedSize d) | qualLeaf d `elem` ext = AnySize Nothing onDim d = d --- | Defunctionalize an application expression at a given depth of application. --- Calls to dynamic (first-order) functions are preserved at much as possible, --- but a new lifted function is created if a dynamic function is only partially --- applied. -defuncApply :: Int -> Exp -> DefM (Exp, StaticVal) -defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do - let (argtypes, _) = unfoldFunType ret - (e1', sv1) <- defuncApply (depth + 1) e1 - case sv1 of - LambdaSV pat e0_t e0 closure_env -> do - (e2', sv2) <- defuncExp e2 - let env' = matchPatSV pat sv2 - dims = mempty - (e0', sv) <- - localNewEnv (env' <> closure_env) $ - defuncExp e0 - - let closure_pat = buildEnvPat dims closure_env - pat' = updatePat pat sv2 - - globals <- asks fst - - -- Lift lambda to top-level function definition. We put in - -- a lot of effort to try to infer the uniqueness attributes - -- of the lifted function, but this is ultimately all a sham - -- and a hack. There is some piece we're missing. - let params = [closure_pat, pat'] - params_for_rettype = params ++ svParams sv1 ++ svParams sv2 - svParams (LambdaSV sv_pat _ _ _) = [sv_pat] - svParams _ = [] - lifted_rettype = buildRetType closure_env params_for_rettype (unRetType e0_t) $ typeOf e0' - - already_bound = - globals - <> S.fromList dims - <> S.map identName (foldMap patIdents params) - - more_dims = - S.toList $ - S.filter (`S.notMember` already_bound) $ - foldMap patternArraySizes params - - -- Embed some information about the original function - -- into the name of the lifted function, to make the - -- result slightly more human-readable. - liftedName i (Var f _ _) = - "defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f) - liftedName i (AppExp (Apply f _ _ _) _) = - liftedName (i + 1) f - liftedName _ _ = "defunc" - - -- Ensure that no parameter sizes are AnySize. The internaliser - -- expects this. This is easy, because they are all - -- first-order. - let bound_sizes = S.fromList (dims <> more_dims) <> globals - (missing_dims, params') <- sizesForAll bound_sizes params - - fname <- newNameFromString $ liftedName (0 :: Int) e1 - liftValDec - fname - (RetType [] $ toStruct lifted_rettype) - (dims ++ more_dims ++ missing_dims) - params' - e0' - - let t1 = toStruct $ typeOf e1' - t2 = toStruct $ typeOf e2' - d1 = Observe - d2 = Observe - fname' = qualName fname - fname'' = - Var - fname' - ( Info - ( Scalar . Arrow mempty Unnamed d1 t1 . RetType [] $ - Scalar . Arrow mempty Unnamed d2 t2 $ - RetType [] lifted_rettype - ) - ) - loc - - callret = AppRes (combineTypeShapes ret lifted_rettype) ext - - innercallret = - AppRes - (Scalar $ Arrow mempty Unnamed d2 t2 $ RetType [] lifted_rettype) - [] - - pure - ( AppExp - ( Apply - ( AppExp - (Apply fname'' e1' (Info (Observe, Nothing)) loc) - (Info innercallret) - ) - e2' - d - loc - ) - (Info callret), - sv - ) - - -- If e1 is a dynamic function, we just leave the application in place, - -- but we update the types since it may be partially applied or return - -- a higher-order term. - DynamicFun _ sv -> do - (e2', _) <- defuncExp e2 - let (argtypes', rettype) = dynamicFunType sv argtypes - restype = foldFunType argtypes' (RetType [] rettype) `setAliases` aliases ret - callret = AppRes (combineTypeShapes ret restype) ext - apply_e = AppExp (Apply e1' e2' d loc) (Info callret) - pure (apply_e, sv) - -- Propagate the 'IntrinsicsSV' until we reach the outermost application, - -- where we construct a dynamic static value with the appropriate type. - IntrinsicSV -> do - e2' <- defuncSoacExp e2 - let e' = AppExp (Apply e1' e2' d loc) t - intrinsicOrHole argtypes e' sv1 - HoleSV {} -> do - (e2', _) <- defuncExp e2 - let e' = AppExp (Apply e1' e2' d loc) t - intrinsicOrHole argtypes e' sv1 - _ -> - error $ - "Application of an expression\n" - ++ prettyString e1 - ++ "\nthat is neither a static lambda " - ++ "nor a dynamic function, but has static value:\n" - ++ show sv1 - where - intrinsicOrHole argtypes e' sv - | depth == 0 = - -- If the intrinsic is fully applied, then we are done. - -- Otherwise we need to eta-expand it and recursively - -- defunctionalise. XXX: might it be better to simply - -- eta-expand immediately any time we encounter a - -- non-fully-applied intrinsic? - if null argtypes - then pure (e', Dynamic $ typeOf e) - else do - (pats, body, tp) <- etaExpand (typeOf e') e' - defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty - | otherwise = pure (e', sv) -defuncApply depth e@(Var qn (Info t) loc) = do +defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal) +defuncApplyFunction e@(Var qn (Info t) loc) num_args = do let (argtypes, _) = unfoldFunType t sv <- lookupVar (toStruct t) (qualLeaf qn) case sv of DynamicFun _ _ - | fullyApplied sv depth -> do + | fullyApplied sv num_args -> do -- We still need to update the types in case the dynamic -- function returns a higher-order term. let (argtypes', rettype) = dynamicFunType sv argtypes pure (Var qn (Info (foldFunType argtypes' $ RetType [] rettype)) loc, sv) | otherwise -> do fname <- newVName $ "dyn_" <> baseString (qualLeaf qn) - let (pats, e0, sv') = liftDynFun (prettyString qn) sv depth + let (pats, e0, sv') = liftDynFun (prettyString qn) sv num_args (argtypes', rettype) = dynamicFunType sv' argtypes dims' = mempty @@ -985,8 +842,131 @@ defuncApply depth e@(Var qn (Info t) loc) = do ) IntrinsicSV -> pure (e, IntrinsicSV) _ -> pure (Var qn (Info (typeFromSV sv)) loc, sv) -defuncApply depth (Parens e _) = defuncApply depth e -defuncApply _ expr = defuncExp expr +defuncApplyFunction e _ = defuncExp e + +-- Embed some information about the original function +-- into the name of the lifted function, to make the +-- result slightly more human-readable. +liftedName :: Int -> Exp -> String +liftedName i (Var f _ _) = + "defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f) +liftedName i (AppExp (Apply f _ _) _) = + liftedName (i + 1) f +liftedName _ _ = "defunc" + +defuncApplyArg :: + String -> + (Exp, StaticVal) -> + (((Diet, Maybe VName), Exp), [(Diet, StructType)]) -> + DefM (Exp, StaticVal) +defuncApplyArg fname_s (f', f_sv@(LambdaSV pat lam_e_t lam_e closure_env)) (((d, argext), arg), _) = do + (arg', arg_sv) <- defuncExp arg + let env' = matchPatSV pat arg_sv + dims = mempty + (lam_e', sv) <- + localNewEnv (env' <> closure_env) $ + defuncExp lam_e + + let closure_pat = buildEnvPat dims closure_env + pat' = updatePat pat arg_sv + + globals <- asks fst + + -- Lift lambda to top-level function definition. We put in + -- a lot of effort to try to infer the uniqueness attributes + -- of the lifted function, but this is ultimately all a sham + -- and a hack. There is some piece we're missing. + let params = [closure_pat, pat'] + params_for_rettype = params ++ svParams f_sv ++ svParams arg_sv + svParams (LambdaSV sv_pat _ _ _) = [sv_pat] + svParams _ = [] + lifted_rettype = buildRetType closure_env params_for_rettype (unRetType lam_e_t) $ typeOf lam_e' + + already_bound = + globals + <> S.fromList dims + <> S.map identName (foldMap patIdents params) + + more_dims = + S.toList $ + S.filter (`S.notMember` already_bound) $ + foldMap patternArraySizes params + + -- Ensure that no parameter sizes are AnySize. The internaliser + -- expects this. This is easy, because they are all + -- first-order. + let bound_sizes = S.fromList (dims <> more_dims) <> globals + (missing_dims, params') <- sizesForAll bound_sizes params + + fname <- newNameFromString fname_s + liftValDec + fname + (RetType [] $ toStruct lifted_rettype) + (dims ++ more_dims ++ missing_dims) + params' + lam_e' + + let f_t = toStruct $ typeOf f' + arg_t = toStruct $ typeOf arg' + d1 = Observe + fname_t = foldFunType [(d1, f_t), (d, arg_t)] $ RetType [] lifted_rettype + fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) + callret = AppRes lifted_rettype [] + + pure + ( mkApply fname' [(Observe, Nothing, f'), (Observe, argext, arg')] callret, + sv + ) +-- If 'f' is a dynamic function, we just leave the application in +-- place, but we update the types since it may be partially +-- applied or return a higher-order value. +defuncApplyArg _ (f', DynamicFun _ sv) (((d, argext), arg), argtypes) = do + (arg', _) <- defuncExp arg + let (argtypes', rettype) = dynamicFunType sv argtypes + restype = foldFunType argtypes' (RetType [] rettype) + callret = AppRes restype [] + apply_e = mkApply f' [(d, argext, arg')] callret + pure (apply_e, sv) +-- +defuncApplyArg _ (_, sv) _ = + error $ "defuncApplyArg: cannot apply StaticVal\n" <> show sv + +updateReturn :: AppRes -> Exp -> Exp +updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = + AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) +updateReturn _ e = e + +defuncApply :: Exp -> NE.NonEmpty ((Diet, Maybe VName), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) +defuncApply f args appres loc = do + (f', f_sv) <- defuncApplyFunction f (length args) + case f_sv of + IntrinsicSV -> do + args' <- fmap (first Info) <$> traverse (traverse defuncSoacExp) args + let e' = AppExp (Apply f' args' loc) (Info appres) + intrinsicOrHole e' + HoleSV {} -> do + args' <- fmap (first Info) <$> traverse (traverse $ fmap fst . defuncExp) args + let e' = AppExp (Apply f' args' loc) (Info appres) + intrinsicOrHole e' + _ -> do + let fname = liftedName 0 f + (argtypes, _) = unfoldFunType $ typeOf f + fmap (first $ updateReturn appres) $ + foldM (defuncApplyArg fname) (f', f_sv) $ + NE.zip args $ + NE.tails argtypes + where + intrinsicOrHole e' = do + -- If the intrinsic is fully applied, then we are done. + -- Otherwise we need to eta-expand it and recursively + -- defunctionalise. XXX: might it be better to simply eta-expand + -- immediately any time we encounter a non-fully-applied + -- intrinsic? + if null $ fst $ unfoldFunType $ appResType appres + then pure (e', Dynamic $ appResType appres) + else do + (pats, body, tp) <- etaExpand (typeOf e') e' + defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty -- | Check if a 'StaticVal' and a given application depth corresponds -- to a fully applied dynamic function. @@ -1181,11 +1161,10 @@ matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType t matchPatSV pat sv = error $ - "Tried to match pattern " + "Tried to match pattern\n" ++ prettyString pat - ++ " with static value " + ++ "\n with static value\n" ++ show sv - ++ "." orderZeroSV :: StaticVal -> Bool orderZeroSV Dynamic {} = True @@ -1235,9 +1214,9 @@ updatePat pat (Dynamic t) = updatePat pat (svFromType t) updatePat pat (HoleSV t _) = updatePat pat (svFromType t) updatePat pat sv = error $ - "Tried to update pattern " + "Tried to update pattern\n" ++ prettyString pat - ++ "to reflect the static value " + ++ "\nto reflect the static value\n" ++ show sv -- | Convert a record (or tuple) type to a record static value. This is used for diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 031c304cc0..f9767615bb 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1443,14 +1443,13 @@ data Function deriving (Show) findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)]) -findFuncall (E.Apply f arg (Info (_, argext)) _) - | E.AppExp f_e _ <- f = - let (f_e', args) = findFuncall f_e - in (f_e', args ++ [(arg, argext)]) +findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = - (FunctionName fname, [(arg, argext)]) + (FunctionName fname, map onArg $ NE.toList args) | E.Hole (Info t) loc <- f = - (FunctionHole t loc, [(arg, argext)]) + (FunctionHole t loc, map onArg $ NE.toList args) + where + onArg (Info (_, argext), e) = (e, argext) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e diff --git a/src/Futhark/Internalise/LiftLambdas.hs b/src/Futhark/Internalise/LiftLambdas.hs index da53916eff..44f3282709 100644 --- a/src/Futhark/Internalise/LiftLambdas.hs +++ b/src/Futhark/Internalise/LiftLambdas.hs @@ -55,9 +55,11 @@ replacing v e = local $ \env -> existentials :: Exp -> S.Set VName existentials e = - let here = case e of - AppExp (Apply _ _ (Info (_, pdim)) _) (Info res) -> - S.fromList (maybeToList pdim ++ appResExt res) + let onArg (Info (_, pdim), _) = + maybeToList pdim + here = case e of + AppExp (Apply _ args _) (Info res) -> + S.fromList (foldMap onArg args <> appResExt res) AppExp _ (Info res) -> S.fromList (appResExt res) _ -> @@ -129,7 +131,7 @@ liftFunction fname tparams params (RetType dims ret) funbody = do apply f [] = f apply f (p : rem_ps) = let inner_ret = AppRes (fromStruct (augType rem_ps)) mempty - inner = AppExp (Apply f (freeVar p) (Info (Observe, Nothing)) mempty) (Info inner_ret) + inner = mkApply f [(Observe, Nothing, freeVar p)] inner_ret in apply inner rem_ps transformExp :: Exp -> LiftM Exp diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index 643639473a..b92eeb5b04 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -219,9 +219,10 @@ transformFName loc fname t applySizeArg (i, f) size_arg = ( i - 1, - AppExp - (Apply f size_arg (Info (Observe, Nothing)) loc) - (Info $ AppRes (foldFunType (replicate i (Observe, i64)) (RetType [] (fromStruct t))) []) + mkApply + f + [(Observe, Nothing, size_arg)] + (AppRes (foldFunType (replicate i (Observe, i64)) (RetType [] (fromStruct t))) []) ) applySizeArgs fname' t' size_args = @@ -310,8 +311,14 @@ transformAppExp (LetFun fname (tparams, params, retdecl, Info ret, body) e loc) <*> pure (Info res) transformAppExp (If e1 e2 e3 loc) res = AppExp <$> (If <$> transformExp e1 <*> transformExp e2 <*> transformExp e3 <*> pure loc) <*> pure (Info res) -transformAppExp (Apply e1 e2 d loc) res = - AppExp <$> (Apply <$> transformExp e1 <*> transformExp e2 <*> pure d <*> pure loc) <*> pure (Info res) +transformAppExp (Apply fe args loc) res = + AppExp + <$> ( Apply + <$> transformExp fe + <*> traverse (traverse transformExp) args + <*> pure loc + ) + <*> pure (Info res) transformAppExp (DoLoop sparams pat e1 form e3 loc) res = do e1' <- transformExp e1 form' <- case form of @@ -357,17 +364,10 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) (AppRes ret ex (Info (AppRes ret mempty)) where applyOp fname' x y = - AppExp - ( Apply - ( AppExp - (Apply fname' x (Info (Observe, snd (unInfo d1))) loc) - (Info $ AppRes ret mempty) - ) - y - (Info (Observe, snd (unInfo d2))) - loc - ) - (Info (AppRes ret ext)) + mkApply + (mkApply fname' [(Observe, snd (unInfo d1), x)] (AppRes ret mempty)) + [(Observe, snd (unInfo d2), y)] + (AppRes ret ext) makeVarParam arg = do let argtype = typeOf arg @@ -533,14 +533,10 @@ desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (Ret (v1, wrap_left, e1, p1) <- makeVarParam e_left $ fromStruct xtype (v2, wrap_right, e2, p2) <- makeVarParam e_right $ fromStruct ytype let apply_left = - AppExp - ( Apply - op - e1 - (Info (Observe, xext)) - loc - ) - (Info $ AppRes (Scalar $ Arrow mempty yp Observe ytype (RetType [] t)) []) + mkApply + op + [(Observe, xext, e1)] + (AppRes (Scalar $ Arrow mempty yp Observe ytype (RetType [] t)) []) rettype' = let onDim (NamedSize d) | Named p <- xp, qualLeaf d == p = NamedSize $ qualName v1 @@ -548,14 +544,7 @@ desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (Ret onDim d = d in first onDim rettype body = - AppExp - ( Apply - apply_left - e2 - (Info (Observe, yext)) - loc - ) - (Info $ AppRes rettype' retext) + mkApply apply_left [(Observe, yext, e2)] (AppRes rettype' retext) rettype'' = toStruct rettype' pure $ wrap_left $ diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index abc6eb923d..ed00c64d6e 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -57,6 +57,7 @@ import Control.Arrow (first) import Control.Concurrent import Control.Exception import Control.Monad +import Control.Monad.State import Crypto.Hash.MD5 as MD5 import Data.ByteString qualified as BS import Data.ByteString.Base16 qualified as Base16 @@ -96,18 +97,23 @@ nubByOrd cmp = map NE.head . NE.groupBy eq . sortBy cmp where eq x y = cmp x y == EQ --- | Like 'Data.Traversable.mapAccumL', but monadic. +-- | Like 'Data.Traversable.mapAccumL', but monadic and generalised to +-- any 'Traversable'. mapAccumLM :: - Monad m => + (Monad m, Traversable t) => (acc -> x -> m (acc, y)) -> acc -> - [x] -> - m (acc, [y]) -mapAccumLM _ acc [] = pure (acc, []) -mapAccumLM f acc (x : xs) = do - (acc', x') <- f acc x - (acc'', xs') <- mapAccumLM f acc' xs - pure (acc'', x' : xs') + t x -> + m (acc, t y) +mapAccumLM op initial l = do + (l', acc) <- runStateT (traverse f l) initial + pure (acc, l') + where + f x = do + acc <- get + (acc', y) <- lift $ op acc x + put acc' + pure y -- | @chunk n a@ splits @a@ into @n@-size-chunks. If the length of -- @a@ is not divisible by @n@, the last chunk will have fewer than diff --git a/src/Language/Futhark/FreeVars.hs b/src/Language/Futhark/FreeVars.hs index 76c89d73dd..65cc3a46cd 100644 --- a/src/Language/Futhark/FreeVars.hs +++ b/src/Language/Futhark/FreeVars.hs @@ -76,7 +76,7 @@ freeInExp expr = case expr of ) <> (freeInExp e2 `freeWithout` S.singleton vn) AppExp (If e1 e2 e3 _) _ -> freeInExp e1 <> freeInExp e2 <> freeInExp e3 - AppExp (Apply e1 e2 _ _) _ -> freeInExp e1 <> freeInExp e2 + AppExp (Apply f args _) _ -> freeInExp f <> foldMap (freeInExp . snd) args Negate e _ -> freeInExp e Not e _ -> freeInExp e Lambda pats e0 _ (Info (_, RetType dims t)) _ -> diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index d430061039..bc811e41d7 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -778,12 +778,15 @@ evalAppExp evalAppExp env _ (If cond e1 e2 _) = do cond' <- asBool <$> eval env cond if cond' then eval env e1 else eval env e2 -evalAppExp env _ (Apply f x (Info (_, ext)) loc) = do - -- It is important that 'x' is evaluated first in order to bring any - -- sizes into scope that may be used in the type of 'f'. - x' <- evalArg env x ext +evalAppExp env _ (Apply f args loc) = do + -- It is important that 'arguments' are evaluated in reverse order + -- in order to bring any sizes into scope that may be used in the + -- type of the functions. + args' <- reverse <$> mapM evalArg' (reverse $ NE.toList args) f' <- eval env f - apply loc env f' x' + foldM (apply loc env) f' args' + where + evalArg' (Info (_, ext), x) = evalArg env x ext evalAppExp env _ (Index e is loc) = do is' <- mapM (evalDimIndex env) is arr <- eval env e diff --git a/src/Language/Futhark/Parser/Monad.hs b/src/Language/Futhark/Parser/Monad.hs index da2f2d47df..0dd15f713d 100644 --- a/src/Language/Futhark/Parser/Monad.hs +++ b/src/Language/Futhark/Parser/Monad.hs @@ -140,8 +140,7 @@ applyExp es = <+> align (pretty index) where index = AppExp (Index e (is ++ map DimFix xs) xloc) NoInfo - op f x = - pure $ AppExp (Apply f x NoInfo (srcspan f x)) NoInfo + op f x = pure $ mkApplyUT f x patternExp :: UncheckedPat -> ParserMonad UncheckedExp patternExp (Id v _ loc) = pure $ Var (qualName v) NoInfo loc diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 210afc3fe6..a1cb1f5ec1 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -302,8 +302,10 @@ prettyAppExp _ (If c t f _) = <+> align (pretty t) "else" <+> align (pretty f) -prettyAppExp p (Apply f arg _ _) = - parensIf (p >= 10) $ prettyExp 0 f <+> prettyExp 10 arg +prettyAppExp p (Apply f args _) = + parensIf (p >= 10) $ + prettyExp 0 f + <+> hsep (map (prettyExp 10 . snd) $ NE.toList args) instance (Eq vn, IsName vn, Annot f) => Pretty (AppExpBase f vn) where pretty = prettyAppExp (-1) diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index 6ab3486366..4af401e045 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -91,6 +91,8 @@ module Language.Futhark.Syntax Alias (..), Aliasing, QualName (..), + mkApply, + mkApplyUT, ) where @@ -622,14 +624,18 @@ instance Located (SizeBinder vn) where -- need, so we can pretend that an application expression was really -- bound to a name. data AppExpBase f vn - = -- | The @Maybe VName@ is a possible existential size that is - -- instantiated by this argument. May have duplicates across the - -- program, but they will all produce the same value (the - -- expressions will be identical). + = -- | Function application. Parts of the compiler expects that the + -- function expression is never itself an 'Apply'. Use the + -- 'mkApply' function to maintain this invariant, rather than + -- constructing 'Apply' directly. + -- + -- The @Maybe VNames@ are existential sizes generated by this + -- argumnet. May have duplicates across the program, but they + -- will all produce the same value (the expressions will be + -- identical). Apply (ExpBase f vn) - (ExpBase f vn) - (f (Diet, Maybe VName)) + (NE.NonEmpty (f (Diet, Maybe VName), ExpBase f vn)) SrcLoc | -- | Size coercion: @e :> t@. Coerce (ExpBase f vn) (TypeExp vn) SrcLoc @@ -696,7 +702,7 @@ instance Located (AppExpBase f vn) where locOf (BinOp _ _ _ _ loc) = locOf loc locOf (If _ _ _ loc) = locOf loc locOf (Coerce _ _ loc) = locOf loc - locOf (Apply _ _ _ loc) = locOf loc + locOf (Apply _ _ loc) = locOf loc locOf (LetPat _ _ _ _ loc) = locOf loc locOf (LetFun _ _ _ loc) = locOf loc locOf (LetWith _ _ _ _ _ loc) = locOf loc @@ -1229,6 +1235,28 @@ deriving instance Show (ProgBase Info VName) deriving instance Show (ProgBase NoInfo Name) +-- | Construct an 'Apply' node, with type information. +mkApply :: ExpBase Info vn -> [(Diet, Maybe VName, ExpBase Info vn)] -> AppRes -> ExpBase Info vn +mkApply f args (AppRes t ext) + | Just args' <- NE.nonEmpty $ map onArg args = + case f of + (AppExp (Apply f' f_args loc) (Info (AppRes _ f_ext))) -> + AppExp + (Apply f' (f_args <> args') (srcspan loc $ snd $ NE.last args')) + (Info $ AppRes t $ f_ext <> ext) + _ -> + AppExp (Apply f args' (srcspan f $ snd $ NE.last args')) (Info (AppRes t ext)) + | otherwise = f + where + onArg (d, v, x) = (Info (d, v), x) + +-- | Construct an 'Apply' node, without type information. +mkApplyUT :: ExpBase NoInfo vn -> ExpBase NoInfo vn -> ExpBase NoInfo vn +mkApplyUT (AppExp (Apply f args loc) _) x = + AppExp (Apply f (args <> NE.singleton (NoInfo, x)) (srcspan loc x)) NoInfo +mkApplyUT f x = + AppExp (Apply f (NE.singleton (NoInfo, x)) (srcspan f x)) NoInfo + --- Some prettyprinting definitions are here because we need them in --- the Attributes module. diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index 1ee8f04032..711c306d0f 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -76,8 +76,8 @@ instance ASTMappable (AppExpBase Info VName) where If <$> mapOnExp tv c <*> mapOnExp tv texp <*> mapOnExp tv fexp <*> pure loc astMap tv (Match e cases loc) = Match <$> mapOnExp tv e <*> astMap tv cases <*> pure loc - astMap tv (Apply f arg d loc) = - Apply <$> mapOnExp tv f <*> mapOnExp tv arg <*> pure d <*> pure loc + astMap tv (Apply f args loc) = + Apply <$> mapOnExp tv f <*> traverse (traverse $ mapOnExp tv) args <*> pure loc astMap tv (LetPat sizes pat e body loc) = LetPat <$> astMap tv sizes <*> astMap tv pat <*> mapOnExp tv e <*> mapOnExp tv body <*> pure loc astMap tv (LetFun name (fparams, params, ret, t, e) body loc) = @@ -497,8 +497,8 @@ bareExp (AppExp appexp _) = BinOp fname NoInfo (bareExp x, NoInfo) (bareExp y, NoInfo) loc If c texp fexp loc -> If (bareExp c) (bareExp texp) (bareExp fexp) loc - Apply f arg _ loc -> - Apply (bareExp f) (bareExp arg) NoInfo loc + Apply f args loc -> + Apply (bareExp f) (fmap ((NoInfo,) . bareExp . snd) args) loc LetPat sizes pat e body loc -> LetPat sizes (barePat pat) (bareExp e) (bareExp body) loc LetFun name (fparams, params, ret, _, e) body loc -> diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 44cc4d25fe..a6c0c132f6 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -21,6 +21,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Futhark.Util (mapAccumLM) import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Primitive (intByteSize) @@ -204,31 +205,6 @@ unscopeType tloc unscoped t = do unAlias (AliasBound v) | v `M.member` unscoped = AliasFree v unAlias a = a --- 'checkApplyExp' is like 'checkExp', but tries to find the "root --- function", for better error messages. -checkApplyExp :: UncheckedExp -> TermTypeM (Exp, ApplyOp) -checkApplyExp (AppExp (Apply e1 e2 _ loc) _) = do - arg <- checkArg e2 - (e1', (fname, i)) <- checkApplyExp e1 - t <- expType e1' - (d1, _, rt, argext, exts) <- checkApply loc (fname, i) t arg - pure - ( AppExp - (Apply e1' (argExp arg) (Info (d1, argext)) loc) - (Info $ AppRes rt exts), - (fname, i + 1) - ) -checkApplyExp e = do - e' <- checkExp e - pure - ( e', - ( case e' of - Var qn _ _ -> Just qn - _ -> Nothing, - 0 - ) - ) - checkExp :: UncheckedExp -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc @@ -437,7 +413,23 @@ checkExp (Negate arg loc) = do checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc -checkExp e@(AppExp Apply {} _) = fst <$> checkApplyExp e +checkExp (AppExp (Apply fe args loc) NoInfo) = do + fe' <- checkExp fe + args' <- mapM (checkArg . snd) args + t <- expType fe' + let fname = + case fe' of + Var v _ _ -> Just v + _ -> Nothing + ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' + pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts + where + onArg fname (i, all_exts, t) arg' = do + (d1, _, rt, argext, exts) <- checkApply loc (fname, i) t arg' + pure + ( (i + 1, all_exts <> exts, rt), + (Info (d1, argext), argExp arg') + ) checkExp (AppExp (LetPat sizes pat e body loc) _) = sequentially (checkExp e) $ \e' e_occs -> do -- Not technically an ascription, but we want the pattern to have @@ -1041,9 +1033,17 @@ causalityCheck binding_body = do onExp known e@(AppExp (LetPat _ _ bindee_e body_e _) (Info res)) = do sequencePoint known bindee_e body_e $ appResExt res pure e - onExp known e@(AppExp (Apply f arg (Info (_, p)) _) (Info res)) = do - sequencePoint known arg f $ maybeToList p ++ appResExt res + onExp known e@(AppExp (Apply f args _) (Info res)) = do + seqArgs known $ reverse $ NE.toList args pure e + where + seqArgs known' [] = do + void $ onExp known' f + modify (S.fromList (appResExt res) <>) + seqArgs known' ((Info (_, p), x) : xs) = do + new_known <- lift $ execStateT (onExp known' x) mempty + void $ seqArgs (new_known <> known') xs + modify ((new_known <> S.fromList (maybeToList p)) <>) onExp known e@(AppExp (BinOp (f, floc) ft (x, Info (_, xp)) (y, Info (_, yp)) _) (Info res)) = do