From 2b94ea17a4f8e15a66b05e061db63bb228452fdc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 12 Feb 2023 15:08:02 +0100 Subject: [PATCH] Make consumption an effect on functions, rather than types. (#1873) This is a breaking change, because until now we allowed functions like def f (a: *[]i32, b: []i32) = ... where we could then pass in a tuple where in an application `f (x,y)` the value `x` would be consumed, but not `y`. However, this became increasingly difficult to support as the language grew (and frankly, it was always buggy). With this commit, the syntax above is still permitted, but it is interpreted as def f ((a,b): *([]i32, []i32)) = ... i.e. the single tuple argument is consumed *as a whole*. Long term we can also consider amending the syntax or warning about cases where it is misleading, but that is less urgent. I've wanted to make this simplification for a long time, but I always hit various snags. Today I managed to make it work, and the next step will be cleaning up the notion of "uniqueness" in return types as well (it should be the more general notion of "aliases"). --- CHANGELOG.md | 3 + docs/error-index.rst | 35 ++ docs/language-reference.rst | 28 +- prelude/ad.fut | 4 +- prelude/array.fut | 6 +- prelude/soacs.fut | 30 +- prelude/zip.fut | 4 +- src/Futhark/Doc/Generator.hs | 20 +- src/Futhark/Internalise/Defunctionalise.hs | 52 +-- src/Futhark/Internalise/Exps.hs | 50 +-- src/Futhark/Internalise/Monomorphise.hs | 27 +- src/Language/Futhark/FreeVars.hs | 2 +- src/Language/Futhark/Interpreter.hs | 182 ++++---- src/Language/Futhark/Pretty.hs | 19 +- src/Language/Futhark/Prop.hs | 399 ++++++++++-------- src/Language/Futhark/Query.hs | 4 +- src/Language/Futhark/Syntax.hs | 24 +- src/Language/Futhark/Traversals.hs | 3 +- src/Language/Futhark/TypeChecker.hs | 14 +- src/Language/Futhark/TypeChecker/Modules.hs | 4 +- src/Language/Futhark/TypeChecker/Monad.hs | 4 +- src/Language/Futhark/TypeChecker/Terms.hs | 99 ++--- .../Futhark/TypeChecker/Terms/Monad.hs | 17 +- src/Language/Futhark/TypeChecker/Types.hs | 34 +- src/Language/Futhark/TypeChecker/Unify.hs | 12 +- tests/accs/intrinsics.fut | 6 +- tests/implicit_method.fut | 5 +- tests/{issue-1774.fut => issue1774.fut} | 0 tests/migration/intrinsics.fut | 18 +- tests/slice-lmads/intrinsics.fut | 12 +- tests/uniqueness/uniqueness-error23.fut | 7 +- unittests/Language/Futhark/SyntaxTests.hs | 15 +- 32 files changed, 610 insertions(+), 529 deletions(-) rename tests/{issue-1774.fut => issue1774.fut} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95780301d6..dff6b1b5c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Changed +* If part of a function parameter is marked as consuming ("unique"), + the *entire* parameter is now marked as consuming. + ### Fixed * A somewhat obscure simplification rule could mess up use of memory. diff --git a/docs/error-index.rst b/docs/error-index.rst index ab75d6dccd..d985650c2d 100644 --- a/docs/error-index.rst +++ b/docs/error-index.rst @@ -149,6 +149,41 @@ inserting copies to break the aliasing: def main (xs: *[]i32) : (*[]i32, *[]i32) = (xs, copy xs) +.. _self-aliasing-arg: + +"Argument passed for consuming parameter is self-aliased." +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Caused by programs like the following: + +.. code-block:: futhark + + def g (t: *([]i64, []i64)) = 0 + + def f n = + let x = iota n + in g (x,x) + +The function ``g`` expects to consume two separate ``[]i64`` arrays, +but ``f`` passes it a tuple containing two references to the same +physical array. This is not allowed, as ``g`` must be allowed to +assume that components of consuming record- or tuple parameters have +no internal aliases. We can fix this by inserting copies to break the +aliasing: + +.. code-block:: futhark + + def f n = + let x = iota n + in g (copy (x,x)) + +Alternative, we could duplicate the expression producing the array: + +.. code-block:: futhark + + def f n = + g (iota n, iota n)) + .. _consuming-parameter: "Consuming parameter passed non-unique argument" diff --git a/docs/language-reference.rst b/docs/language-reference.rst index ae6efbcc5e..19f3942e79 100644 --- a/docs/language-reference.rst +++ b/docs/language-reference.rst @@ -1434,15 +1434,27 @@ prefixing it with an asterisk. For a return type, we can mark it as def modify (a: *[]i32) (i: i32) (x: i32): *[]i32 = a with [i] = a[i] + x +A parameter that is not consuming is called *observing*. In the +parameter declaration ``a: *[i32]``, the asterisk means that the +function ``modify`` has been given "ownership" of the array ``a``, +meaning that any caller of ``modify`` will never reference array ``a`` +after the call again. This allows the ``with`` expression to perform +an in-place update. After a call ``modify a i x``, neither ``a`` or +any variable that *aliases* ``a`` may be used on any following +execution path. + +If an asterisk is present at *any point* inside a tuple parameter +type, the parameter as a whole is considered consuming. For example:: + + def consumes_both ((a,b): (*[]i32,[]i32)) = ... + +This is usually not desirable behaviour. Use multiple parameters +instead:: + + def consumes_first_arg (a: *[]i32) (b: []i32) = ... + For bulk in-place updates with multiple values, use the ``scatter`` -function in the basis library. In the parameter declaration ``a: -*[i32]``, the asterisk means that the function ``modify`` has been -given "ownership" of the array ``a``, meaning that any caller of -``modify`` will never reference array ``a`` after the call again. -This allows the ``with`` expression to perform an in-place update. - -After a call ``modify a i x``, neither ``a`` or any variable that -*aliases* ``a`` may be used on any following execution path. +function in the basis library. Alias Analysis ~~~~~~~~~~~~~~ diff --git a/prelude/ad.fut b/prelude/ad.fut index 91eb24ffbc..251e47a21d 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -7,12 +7,12 @@ -- | Jacobian-Vector Product ("forward mode"), producing also the -- primal result as the first element of the result tuple. let jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) = - intrinsics.jvp2 (f, x, x') + intrinsics.jvp2 f x x' -- | Vector-Jacobian Product ("reverse mode"), producing also the -- primal result as the first element of the result tuple. let vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = - intrinsics.vjp2 (f, x, y') + intrinsics.vjp2 f x y' -- | Jacobian-Vector Product ("forward mode"). let jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = diff --git a/prelude/array.fut b/prelude/array.fut index 7c880af8f1..b58046a4d7 100644 --- a/prelude/array.fut +++ b/prelude/array.fut @@ -62,7 +62,7 @@ def reverse [n] 't (x: [n]t): [n]t = x[::-1] -- **Work:** O(n). -- -- **Span:** O(1). -def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = intrinsics.concat (xs, ys) +def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = intrinsics.concat xs ys -- | An old-fashioned way of saying `++`. def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = xs ++ ys @@ -83,7 +83,7 @@ def concat_to [n] [m] 't (k: i64) (xs: [n]t) (ys: [m]t): *[k]t = xs ++ ys :> [k] -- -- Note: In most cases, `rotate` will be fused with subsequent -- operations such as `map`, in which case it is free. -def rotate [n] 't (r: i64) (xs: [n]t): [n]t = intrinsics.rotate (r, xs) +def rotate [n] 't (r: i64) (xs: [n]t): [n]t = intrinsics.rotate r xs -- | Construct an array of consecutive integers of the given length, -- starting at 0. @@ -143,7 +143,7 @@ def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): []t = -- -- **Complexity:** O(1). def unflatten [p] 't (n: i64) (m: i64) (xs: [p]t): [n][m]t = - intrinsics.unflatten (n, m, xs) :> [n][m]t + intrinsics.unflatten n m xs :> [n][m]t -- | Like `unflatten`, but produces three dimensions. def unflatten_3d [p] 't (n: i64) (m: i64) (l: i64) (xs: [p]t): [n][m][l]t = diff --git a/prelude/soacs.fut b/prelude/soacs.fut index 67a43b4fb6..fe7f632ccf 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -48,7 +48,7 @@ import "zip" -- -- **Span:** *O(S(f))* def map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map (f, as) + intrinsics.map f as -- | Apply the given function to each element of a single array. -- @@ -104,7 +104,7 @@ def map5 'a 'b 'c 'd 'e [n] 'x (f: a -> b -> c -> d -> e -> x) (as: [n]a) (bs: [ -- Note that the complexity implies that parallelism in the combining -- operator will *not* be exploited. def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = - intrinsics.reduce (op, ne, as) + intrinsics.reduce op ne as -- | As `reduce`, but the operator must also be commutative. This is -- potentially faster than `reduce`. For simple built-in operators, @@ -115,7 +115,7 @@ def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = -- -- **Span:** *O(log(n) ✕ W(op))* def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = - intrinsics.reduce_comm (op, ne, as) + intrinsics.reduce_comm op ne as -- | `h = hist op ne k is as` computes a generalised `k`-bin histogram -- `h`, such that `h[i]` is the sum of those values `as[j]` for which @@ -130,7 +130,7 @@ def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = -- -- In practice, linear span only occurs if *k* is also very large. def hist 'a [n] (op: a -> a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k]a = - intrinsics.hist_1d (1, map (\_ -> ne) (0..1.. ne) (0..1.. a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k -- -- In practice, linear span only occurs if *k* is also very large. def reduce_by_index 'a [k] [n] (dest : *[k]a) (f : a -> a -> a) (ne : a) (is : [n]i64) (as : [n]a) : *[k]a = - intrinsics.hist_1d (1, dest, f, ne, is, as) + intrinsics.hist_1d 1 dest f ne is as -- | As `reduce_by_index`, but with two-dimensional indexes. def reduce_by_index_2d 'a [k] [n] [m] (dest : *[k][m]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64)) (as : [n]a) : *[k][m]a = - intrinsics.hist_2d (1, dest, f, ne, is, as) + intrinsics.hist_2d 1 dest f ne is as -- | As `reduce_by_index`, but with three-dimensional indexes. def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64,i64)) (as : [n]a) : *[k][m][l]a = - intrinsics.hist_3d (1, dest, f, ne, is, as) + intrinsics.hist_3d 1 dest f ne is as -- | Inclusive prefix scan. Has the same caveats with respect to -- associativity and complexity as `reduce`. @@ -160,7 +160,7 @@ def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a) -- -- **Span:** *O(log(n) ✕ W(op))* def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a = - intrinsics.scan (op, ne, as) + intrinsics.scan op ne as -- | Remove all those elements of `as` that do not satisfy the -- predicate `p`. @@ -169,7 +169,7 @@ def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a = -- -- **Span:** *O(log(n) ✕ W(p))* def filter [n] 'a (p: a -> bool) (as: [n]a): *[]a = - let (as', is) = intrinsics.partition (1, \x -> if p x then 0 else 1, as) + let (as', is) = intrinsics.partition 1 (\x -> if p x then 0 else 1) as in as'[:is[0]] -- | Split an array into those elements that satisfy the given @@ -180,7 +180,7 @@ def filter [n] 'a (p: a -> bool) (as: [n]a): *[]a = -- **Span:** *O(log(n) ✕ W(p))* def partition [n] 'a (p: a -> bool) (as: [n]a): ([]a, []a) = let p' x = if p x then 0 else 1 - let (as', is) = intrinsics.partition (2, p', as) + let (as', is) = intrinsics.partition 2 p' as in (as'[0:is[0]], as'[is[0]:n]) -- | Split an array by two predicates, producing three arrays. @@ -190,7 +190,7 @@ def partition [n] 'a (p: a -> bool) (as: [n]a): ([]a, []a) = -- **Span:** *O(log(n) ✕ (W(p1) + W(p2)))* def partition2 [n] 'a (p1: a -> bool) (p2: a -> bool) (as: [n]a): ([]a, []a, []a) = let p' x = if p1 x then 0 else if p2 x then 1 else 2 - let (as', is) = intrinsics.partition (3, p', as) + let (as', is) = intrinsics.partition 3 p' as in (as'[0:is[0]], as'[is[0]:is[0]+is[1]], as'[is[0]+is[1]:n]) -- | Return `true` if the given function returns `true` for all @@ -223,7 +223,7 @@ def any [n] 'a (f: a -> bool) (as: [n]a): bool = -- -- **Span:** *O(1)* def spread 't [n] (k: i64) (x: t) (is: [n]i64) (vs: [n]t): *[k]t = - intrinsics.scatter (map (\_ -> x) (0..1.. x) (0..1.. x) (as: [n]a): [n]x = - intrinsics.map (f, as) + intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = - intrinsics.zip (as, bs) + intrinsics.zip as bs -- | Construct an array of pairs from two arrays. def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = diff --git a/src/Futhark/Doc/Generator.hs b/src/Futhark/Doc/Generator.hs index 81e7fdcf1d..574acf158d 100644 --- a/src/Futhark/Doc/Generator.hs +++ b/src/Futhark/Doc/Generator.hs @@ -419,7 +419,7 @@ valBindHtml name (ValBind _ _ retdecl (Info rettype) tparams params _ _ _ _) = d map typeParamName tparams ++ map identName (S.toList $ mconcat $ map patIdents params) rettype' <- noLink' $ maybe (retTypeHtml rettype) typeExpHtml retdecl - params' <- noLink' $ mapM patternHtml params + params' <- noLink' $ mapM paramHtml params pure ( keyword "val " <> (H.span ! A.class_ "decl_name") name, tparams', @@ -493,6 +493,10 @@ synopsisValBindBind (name, BoundV tps t) = do <> ": " <> t' +dietHtml :: Diet -> Html +dietHtml Consume = "*" +dietHtml Observe = "" + typeHtml :: StructType -> DocM Html typeHtml t = case t of Array _ u shape et -> do @@ -513,14 +517,14 @@ typeHtml t = case t of targs' <- mapM typeArgHtml targs et' <- qualNameHtml et pure $ prettyU u <> et' <> mconcat (map (" " <>) targs') - Scalar (Arrow _ pname t1 t2) -> do + Scalar (Arrow _ pname d t1 t2) -> do t1' <- typeHtml t1 t2' <- retTypeHtml t2 pure $ case pname of Named v -> - parens (vnameHtml v <> ": " <> t1') <> " -> " <> t2' + parens (vnameHtml v <> ": " <> dietHtml d <> t1') <> " -> " <> t2' Unnamed -> - t1' <> " -> " <> t2' + dietHtml d <> t1' <> " -> " <> t2' Scalar (Sum cs) -> pipes <$> mapM ppClause (sortConstrs cs) where ppClause (n, ts) = joinBy " " . (ppConstr n :) <$> mapM typeHtml ts @@ -688,12 +692,12 @@ vnameLink' (VName _ tag) current file = then "#" ++ show tag else relativise file current ++ ".html#" ++ show tag -patternHtml :: Pat -> DocM Html -patternHtml pat = do - let (pat_param, t) = patternParam pat +paramHtml :: Pat -> DocM Html +paramHtml pat = do + let (pat_param, d, t) = patternParam pat t' <- typeHtml t pure $ case pat_param of - Named v -> parens (vnameHtml v <> ": " <> t') + Named v -> parens (vnameHtml v <> ": " <> dietHtml d <> t') Unnamed -> t' relativise :: FilePath -> FilePath -> FilePath diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 00bfa9fdff..36c5d20e05 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -384,7 +384,7 @@ defuncFun tparams pats e0 ret loc = do [pat'] -> (pat', ret, e0) (pat' : pats') -> ( pat', - RetType [] $ foldFunType (map (toStruct . patternType) pats') ret, + RetType [] $ funType pats' ret, Lambda pats' e0 Nothing (Info (mempty, ret)) loc ) @@ -533,15 +533,6 @@ defuncExp (AppExp (If e1 e2 e3 loc) res) = do (e2', sv) <- defuncExp e2 (e3', _) <- defuncExp e3 pure (AppExp (If e1' e2' e3' loc) res, sv) -defuncExp e@(AppExp (Apply f@(Var f' _ _) arg d loc) res) - | baseTag (qualLeaf f') <= maxIntrinsicTag, - TupLit es tuploc <- arg = do - -- defuncSoacExp also works fine for non-SOACs. - es' <- mapM defuncSoacExp es - pure - ( AppExp (Apply f (TupLit es' tuploc) d loc) res, - Dynamic $ typeOf e - ) defuncExp e@(AppExp Apply {} _) = defuncApply 0 e defuncExp (Negate e0 loc) = do (e0', sv) <- defuncExp e0 @@ -720,15 +711,19 @@ etaExpand e_t e = do (Info (AppRes (foldFunType argtypes ret) [])) ) e - $ zip3 vars (map snd ps) (drop 1 $ tails $ map snd ps) + $ zip3 vars (map (snd . snd) ps) (drop 1 $ tails $ map snd ps) pure (pats, e', second (const ()) ret) where - getType (RetType _ (Scalar (Arrow _ p t1 t2))) = - let (ps, r) = getType t2 in ((p, t1) : ps, r) + getType (RetType _ (Scalar (Arrow _ p d t1 t2))) = + let (ps, r) = getType t2 in ((p, (d, t1)) : ps, r) getType t = ([], t) - f prev (p, t) = do - let t' = fromStruct t + f prev (p, (d, t)) = do + let t' = + fromStruct t + `setUniqueness` case d of + Consume -> Unique + Observe -> Nonunique x <- case p of Named x | x `notElem` prev -> pure x _ -> newNameFromString "x" @@ -820,10 +815,9 @@ defuncApply :: Int -> Exp -> DefM (Exp, StaticVal) defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do let (argtypes, _) = unfoldFunType ret (e1', sv1) <- defuncApply (depth + 1) e1 - (e2', sv2) <- defuncExp e2 - let e' = AppExp (Apply e1' e2' d loc) t case sv1 of LambdaSV pat e0_t e0 closure_env -> do + (e2', sv2) <- defuncExp e2 let env' = matchPatSV pat sv2 dims = mempty (e0', sv) <- @@ -880,13 +874,15 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do let t1 = toStruct $ typeOf e1' t2 = toStruct $ typeOf e2' + d1 = Observe + d2 = Observe fname' = qualName fname fname'' = Var fname' ( Info - ( Scalar . Arrow mempty Unnamed t1 . RetType [] $ - Scalar . Arrow mempty Unnamed t2 $ + ( Scalar . Arrow mempty Unnamed d1 t1 . RetType [] $ + Scalar . Arrow mempty Unnamed d2 t2 $ RetType [] lifted_rettype ) ) @@ -896,7 +892,7 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do innercallret = AppRes - (Scalar $ Arrow mempty Unnamed t2 $ RetType [] lifted_rettype) + (Scalar $ Arrow mempty Unnamed d2 t2 $ RetType [] lifted_rettype) [] pure @@ -918,6 +914,7 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do -- but we update the types since it may be partially applied or return -- a higher-order term. DynamicFun _ sv -> do + (e2', _) <- defuncExp e2 let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) `setAliases` aliases ret callret = AppRes (combineTypeShapes ret restype) ext @@ -925,8 +922,14 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do pure (apply_e, sv) -- Propagate the 'IntrinsicsSV' until we reach the outermost application, -- where we construct a dynamic static value with the appropriate type. - IntrinsicSV -> intrinsicOrHole argtypes e' sv1 - HoleSV {} -> intrinsicOrHole argtypes e' sv1 + IntrinsicSV -> do + e2' <- defuncSoacExp e2 + let e' = AppExp (Apply e1' e2' d loc) t + intrinsicOrHole argtypes e' sv1 + HoleSV {} -> do + (e2', _) <- defuncExp e2 + let e' = AppExp (Apply e1' e2' d loc) t + intrinsicOrHole argtypes e' sv1 _ -> error $ "Application of an expression\n" @@ -1129,9 +1132,10 @@ typeFromSV IntrinsicSV = -- | Construct the type for a fully-applied dynamic function from its -- static value and the original types of its arguments. -dynamicFunType :: StaticVal -> [StructType] -> ([PatType], PatType) +dynamicFunType :: StaticVal -> [(Diet, StructType)] -> ([(Diet, PatType)], PatType) dynamicFunType (DynamicFun _ sv) (p : ps) = - let (ps', ret) = dynamicFunType sv ps in (fromStruct p : ps', ret) + let (ps', ret) = dynamicFunType sv ps + in (second fromStruct p : ps', ret) dynamicFunType sv _ = ([], typeFromSV sv) -- | Match a pattern with its static value. Returns an environment with diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 1a45cfd2ba..031c304cc0 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1602,13 +1602,13 @@ isIntrinsicFunction qname args loc = do fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing - handleSOACs [TupLit [lam, arr] _] "map" = Just $ \desc -> do + handleSOACs [lam, arr] "map" = Just $ \desc -> do arr' <- internaliseExpToVars "map_arr" arr arr_ts <- mapM lookupType arr' lam' <- internaliseLambdaCoerce lam $ map rowType arr_ts let w = arraysSize 0 arr_ts letTupExp' desc $ I.Op $ I.Screma w arr' (I.mapSOAC lam') - handleSOACs [TupLit [k, lam, arr] _] "partition" = do + handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do arrs <- internaliseExpToVars "partition_input" arr @@ -1618,43 +1618,43 @@ isIntrinsicFunction qname args loc = do fromInt32 (Literal (SignedValue (Int32Value k')) _) = Just k' fromInt32 (IntLit k' (Info (E.Scalar (E.Prim (E.Signed Int32)))) _) = Just $ fromInteger k' fromInt32 _ = Nothing - handleSOACs [TupLit [lam, ne, arr] _] "reduce" = Just $ \desc -> + handleSOACs [lam, ne, arr] "reduce" = Just $ \desc -> internaliseScanOrReduce desc "reduce" reduce (lam, ne, arr, loc) where reduce w red_lam nes arrs = I.Screma w arrs <$> I.reduceSOAC [Reduce Noncommutative red_lam nes] - handleSOACs [TupLit [lam, ne, arr] _] "reduce_comm" = Just $ \desc -> + handleSOACs [lam, ne, arr] "reduce_comm" = Just $ \desc -> internaliseScanOrReduce desc "reduce" reduce (lam, ne, arr, loc) where reduce w red_lam nes arrs = I.Screma w arrs <$> I.reduceSOAC [Reduce Commutative red_lam nes] - handleSOACs [TupLit [lam, ne, arr] _] "scan" = Just $ \desc -> + handleSOACs [lam, ne, arr] "scan" = Just $ \desc -> internaliseScanOrReduce desc "scan" reduce (lam, ne, arr, loc) where reduce w scan_lam nes arrs = I.Screma w arrs <$> I.scanSOAC [Scan scan_lam nes] - handleSOACs [TupLit [rf, dest, op, ne, buckets, img] _] "hist_1d" = Just $ \desc -> + handleSOACs [rf, dest, op, ne, buckets, img] "hist_1d" = Just $ \desc -> internaliseHist 1 desc rf dest op ne buckets img loc - handleSOACs [TupLit [rf, dest, op, ne, buckets, img] _] "hist_2d" = Just $ \desc -> + handleSOACs [rf, dest, op, ne, buckets, img] "hist_2d" = Just $ \desc -> internaliseHist 2 desc rf dest op ne buckets img loc - handleSOACs [TupLit [rf, dest, op, ne, buckets, img] _] "hist_3d" = Just $ \desc -> + handleSOACs [rf, dest, op, ne, buckets, img] "hist_3d" = Just $ \desc -> internaliseHist 3 desc rf dest op ne buckets img loc handleSOACs _ _ = Nothing - handleAccs [TupLit [dest, f, bs] _] "scatter_stream" = Just $ \desc -> + handleAccs [dest, f, bs] "scatter_stream" = Just $ \desc -> internaliseStreamAcc desc dest Nothing f bs - handleAccs [TupLit [dest, op, ne, f, bs] _] "hist_stream" = Just $ \desc -> + handleAccs [dest, op, ne, f, bs] "hist_stream" = Just $ \desc -> internaliseStreamAcc desc dest (Just (op, ne)) f bs - handleAccs [TupLit [acc, i, v] _] "acc_write" = Just $ \desc -> do + handleAccs [acc, i, v] "acc_write" = Just $ \desc -> do acc' <- head <$> internaliseExpToVars "acc" acc i' <- internaliseExp1 "acc_i" i vs <- internaliseExp "acc_v" v fmap pure $ letSubExp desc $ BasicOp $ UpdateAcc acc' [i'] vs handleAccs _ _ = Nothing - handleAD [TupLit [f, x, v] _] fname + handleAD [f, x, v] fname | fname `elem` ["jvp2", "vjp2"] = Just $ \desc -> do x' <- internaliseExp "ad_x" x v' <- internaliseExp "ad_v" v @@ -1665,10 +1665,10 @@ isIntrinsicFunction qname args loc = do _ -> VJP lam x' v' handleAD _ _ = Nothing - handleRest [E.TupLit [a, si, v] _] "scatter" = Just $ scatterF 1 a si v - handleRest [E.TupLit [a, si, v] _] "scatter_2d" = Just $ scatterF 2 a si v - handleRest [E.TupLit [a, si, v] _] "scatter_3d" = Just $ scatterF 3 a si v - handleRest [E.TupLit [n, m, arr] _] "unflatten" = Just $ \desc -> do + handleRest [a, si, v] "scatter" = Just $ scatterF 1 a si v + handleRest [a, si, v] "scatter_2d" = Just $ scatterF 2 a si v + handleRest [a, si, v] "scatter_3d" = Just $ scatterF 3 a si v + handleRest [n, m, arr] "unflatten" = Just $ \desc -> do arrs <- internaliseExpToVars "unflatten_arr" arr n' <- internaliseExp1 "n" n m' <- internaliseExp1 "m" m @@ -1716,7 +1716,7 @@ isIntrinsicFunction qname args loc = do I.ReshapeArbitrary (reshapeOuter (I.Shape [k]) 2 $ I.arrayShape arr_t) arr' - handleRest [TupLit [x, y] _] "concat" = Just $ \desc -> do + handleRest [x, y] "concat" = Just $ \desc -> do xs <- internaliseExpToVars "concat_x" x ys <- internaliseExpToVars "concat_y" y outer_size <- arraysSize 0 <$> mapM lookupType xs @@ -1731,7 +1731,7 @@ isIntrinsicFunction qname args loc = do let conc xarr yarr = I.BasicOp $ I.Concat 0 (xarr :| [yarr]) ressize mapM (letSubExp desc) $ zipWith conc xs ys - handleRest [TupLit [offset, e] _] "rotate" = Just $ \desc -> do + handleRest [offset, e] "rotate" = Just $ \desc -> do offset' <- internaliseExp1 "rotation_offset" offset internaliseOperation desc e $ \v -> do r <- I.arrayRank <$> lookupType v @@ -1742,24 +1742,24 @@ isIntrinsicFunction qname args loc = do internaliseOperation desc e $ \v -> do r <- I.arrayRank <$> lookupType v pure $ I.Rearrange ([1, 0] ++ [2 .. r - 1]) v - handleRest [TupLit [x, y] _] "zip" = Just $ \desc -> + handleRest [x, y] "zip" = Just $ \desc -> mapM (letSubExp "zip_copy" . BasicOp . Copy) =<< ( (++) <$> internaliseExpToVars (desc ++ "_zip_x") x <*> internaliseExpToVars (desc ++ "_zip_y") y ) handleRest [x] "unzip" = Just $ flip internaliseExp x - handleRest [TupLit [arr, offset, n1, s1, n2, s2] _] "flat_index_2d" = Just $ \desc -> do + handleRest [arr, offset, n1, s1, n2, s2] "flat_index_2d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2)] - handleRest [TupLit [arr1, offset, s1, s2, arr2] _] "flat_update_2d" = Just $ \desc -> do + handleRest [arr1, offset, s1, s2, arr2] "flat_update_2d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2] arr2 - handleRest [TupLit [arr, offset, n1, s1, n2, s2, n3, s3] _] "flat_index_3d" = Just $ \desc -> do + handleRest [arr, offset, n1, s1, n2, s2, n3, s3] "flat_index_3d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2), (n3, s3)] - handleRest [TupLit [arr1, offset, s1, s2, s3, arr2] _] "flat_update_3d" = Just $ \desc -> do + handleRest [arr1, offset, s1, s2, s3, arr2] "flat_update_3d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2, s3] arr2 - handleRest [TupLit [arr, offset, n1, s1, n2, s2, n3, s3, n4, s4] _] "flat_index_4d" = Just $ \desc -> do + handleRest [arr, offset, n1, s1, n2, s2, n3, s3, n4, s4] "flat_index_4d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2), (n3, s3), (n4, s4)] - handleRest [TupLit [arr1, offset, s1, s2, s3, s4, arr2] _] "flat_update_4d" = Just $ \desc -> do + handleRest [arr1, offset, s1, s2, s3, s4, arr2] "flat_update_4d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2, s3, s4] arr2 handleRest _ _ = Nothing diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index b9e14dab2b..643639473a 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -221,7 +221,7 @@ transformFName loc fname t ( i - 1, AppExp (Apply f size_arg (Info (Observe, Nothing)) loc) - (Info $ AppRes (foldFunType (replicate i i64) (RetType [] (fromStruct t))) []) + (Info $ AppRes (foldFunType (replicate i (Observe, i64)) (RetType [] (fromStruct t))) []) ) applySizeArgs fname' t' size_args = @@ -233,7 +233,7 @@ transformFName loc fname t (qualName fname') ( Info ( foldFunType - (map (const i64) size_args) + (map (const (Observe, i64)) size_args) (RetType [] $ fromStruct t') ) ) @@ -540,7 +540,7 @@ desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (Ret (Info (Observe, xext)) loc ) - (Info $ AppRes (Scalar $ Arrow mempty yp ytype (RetType [] t)) []) + (Info $ AppRes (Scalar $ Arrow mempty yp Observe ytype (RetType [] t)) []) rettype' = let onDim (NamedSize d) | Named p <- xp, qualLeaf d == p = NamedSize $ qualName v1 @@ -580,7 +580,7 @@ desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (Ret pure (v, id, var_e, [pat]) desugarProjectSection :: [Name] -> PatType -> SrcLoc -> MonoM Exp -desugarProjectSection fields (Scalar (Arrow _ _ t1 (RetType dims t2))) loc = do +desugarProjectSection fields (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do p <- newVName "project_p" let body = foldl project (Var (qualName p) (Info t1') mempty) fields pure $ @@ -606,7 +606,7 @@ desugarProjectSection fields (Scalar (Arrow _ _ t1 (RetType dims t2))) loc = do desugarProjectSection _ t _ = error $ "desugarOpSection: not a function type: " ++ prettyString t desugarIndexSection :: [DimIndex] -> PatType -> SrcLoc -> MonoM Exp -desugarIndexSection idxs (Scalar (Arrow _ _ t1 (RetType dims t2))) loc = do +desugarIndexSection idxs (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do p <- newVName "index_i" let body = AppExp (Index (Var (qualName p) (Info t1') loc) idxs loc) (Info (AppRes t2 [])) pure $ @@ -719,8 +719,8 @@ noNamedParams = f where f (Array () u shape t) = Array () u shape (f' t) f (Scalar t) = Scalar $ f' t - f' (Arrow () _ t1 (RetType dims t2)) = - Arrow () Unnamed (f t1) (RetType dims (f t2)) + f' (Arrow () _ d1 t1 (RetType dims t2)) = + Arrow () Unnamed d1 (f t1) (RetType dims (f t2)) f' (Record fs) = Record $ fmap f fs f' (Sum cs) = @@ -737,7 +737,7 @@ monomorphiseBinding :: MonoM (VName, InferSizeArgs, ValBind) monomorphiseBinding entry (PolyBinding rr (name, tparams, params, rettype, body, attrs, loc)) inst_t = replaceRecordReplacements rr $ do - let bind_t = foldFunType (map patternStructType params) rettype + let bind_t = funType params rettype (substs, t_shape_params) <- typeSubstsM loc (noSizes bind_t) $ noNamedParams inst_t let substs' = M.map (Subst []) substs rettype' = applySubst (`M.lookup` substs') rettype @@ -840,7 +840,7 @@ typeSubstsM loc orig_t1 orig_t2 = (map snd $ sortFields fields1) (map snd $ sortFields fields2) sub (Scalar Prim {}) (Scalar Prim {}) = pure () - sub (Scalar (Arrow _ _ t1a (RetType _ t1b))) (Scalar (Arrow _ _ t2a t2b)) = do + sub (Scalar (Arrow _ _ _ t1a (RetType _ t1b))) (Scalar (Arrow _ _ _ t2a t2b)) = do sub t1a t2a subRet t1b t2b sub (Scalar (Sum cs1)) (Scalar (Sum cs2)) = @@ -940,11 +940,10 @@ transformValBind valbind = do Nothing -> pure () Just (Info entry) -> do t <- - removeTypeVariablesInType - $ foldFunType - (map patternStructType (valBindParams valbind)) - $ unInfo - $ valBindRetType valbind + removeTypeVariablesInType $ + funType (valBindParams valbind) $ + unInfo $ + valBindRetType valbind (name, infer, valbind'') <- monomorphiseBinding True valbind' $ monoType t entry' <- transformEntryPoint entry tell $ Seq.singleton (name, valbind'' {valBindEntryPoint = Just $ Info entry'}) diff --git a/src/Language/Futhark/FreeVars.hs b/src/Language/Futhark/FreeVars.hs index 5e8c0dfe83..76c89d73dd 100644 --- a/src/Language/Futhark/FreeVars.hs +++ b/src/Language/Futhark/FreeVars.hs @@ -148,7 +148,7 @@ freeInType t = mempty Scalar (Sum cs) -> foldMap (foldMap freeInType) cs - Scalar (Arrow _ v t1 (RetType dims t2)) -> + Scalar (Arrow _ v _ t1 (RetType dims t2)) -> S.filter (notV v) $ S.filter (`notElem` dims) $ freeInType t1 <> freeInType t2 Scalar (TypeVar _ _ _ targs) -> foldMap typeArgDims targs diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 5b3ee8f216..d430061039 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -154,8 +154,8 @@ resolveTypeParams names = match M.elems $ M.intersectionWith (zipWith match) poly_fields fields match - (Scalar (Arrow _ _ poly_t1 (RetType _ poly_t2))) - (Scalar (Arrow _ _ t1 (RetType _ t2))) = + (Scalar (Arrow _ _ _ poly_t1 (RetType _ poly_t2))) + (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = match poly_t1 t1 <> match poly_t2 t2 match poly_t t | d1 : _ <- shapeDims (arrayShape poly_t), @@ -547,8 +547,8 @@ evalIndex loc env is arr = do evalType :: Env -> StructType -> StructType evalType _ (Scalar (Prim pt)) = Scalar $ Prim pt evalType env (Scalar (Record fs)) = Scalar $ Record $ fmap (evalType env) fs -evalType env (Scalar (Arrow () p t1 (RetType dims t2))) = - Scalar $ Arrow () p (evalType env t1) (RetType dims (evalType env t2)) +evalType env (Scalar (Arrow () p d t1 (RetType dims t2))) = + Scalar $ Arrow () p d (evalType env t1) (RetType dims (evalType env t2)) evalType env t@(Array _ u shape _) = let et = stripArray (shapeRank shape) t et' = evalType env et @@ -611,7 +611,7 @@ evalFunction env _ [] body rettype = -- Eta-expand the rest to make any sizes visible. etaExpand [] env rettype where - etaExpand vs env' (Scalar (Arrow _ _ pt (RetType _ rt))) = + etaExpand vs env' (Scalar (Arrow _ _ _ pt (RetType _ rt))) = pure $ ValueFun $ \v -> do env'' <- matchPat env' (Wildcard (Info $ fromStruct pt) noLoc) v @@ -640,7 +640,7 @@ evalFunctionBinding :: EvalM TermBinding evalFunctionBinding env tparams ps ret fbody = do let ret' = evalType env $ retType ret - arrow (xp, xt) yt = Scalar $ Arrow () xp xt $ RetType [] yt + arrow (xp, d, xt) yt = Scalar $ Arrow () xp d xt $ RetType [] yt ftype = foldr (arrow . patternParam) ret' ps retext = case ps of [] -> retDims ret @@ -1231,57 +1231,61 @@ initialCtx = fun1 f = TermValue Nothing $ ValueFun $ \x -> f x + fun2 f = - TermValue Nothing $ - ValueFun $ \x -> - pure $ ValueFun $ \y -> f x y - fun2t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y] -> f x y - _ -> error $ "Expected pair; got: " <> show v - fun3t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z] -> f x y z - _ -> error $ "Expected triple; got: " <> show v - - fun5t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b] -> f x y z a b - _ -> error $ "Expected pentuple; got: " <> show v - - fun6t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c] -> f x y z a b c - _ -> error $ "Expected sextuple; got: " <> show v - - fun7t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c, d] -> f x y z a b c d - _ -> error $ "Expected septuple; got: " <> show v - - fun8t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c, d, e] -> f x y z a b c d e - _ -> error $ "Expected sextuple; got: " <> show v - - fun10t fun = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c, d, e, f, g] -> fun x y z a b c d e f g - _ -> error $ "Expected octuple; got: " <> show v + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> f x y + + fun3 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> f x y z + + fun5 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> f x y z a b + + fun6 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> f x y z a b c + + fun7 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> + pure . ValueFun $ \d -> f x y z a b c d + + fun8 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> + pure . ValueFun $ \d -> + pure . ValueFun $ \e -> f x y z a b c d e + + fun10 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> + pure . ValueFun $ \d -> + pure . ValueFun $ \e -> + pure . ValueFun $ \g -> + pure . ValueFun $ \h -> f x y z a b c d e g h bopDef fs = fun2 $ \x y -> case (x, y) of @@ -1459,14 +1463,14 @@ initialCtx = _ -> error $ "Cannot unsign: " <> show x def s | "map_stream" `isPrefixOf` s = - Just $ fun2t stream + Just $ fun2 stream def s | "reduce_stream" `isPrefixOf` s = - Just $ fun3t $ \_ f arg -> stream f arg + Just $ fun3 $ \_ f arg -> stream f arg def "map" = Just $ TermPoly Nothing $ \t -> pure $ - ValueFun $ \v -> - case (fromTuple v, unfoldFunType t) of - (Just [f, xs], ([_], ret_t)) + ValueFun $ \f -> pure . ValueFun $ \xs -> + case unfoldFunType t of + ([_, _], ret_t) | Just rowshape <- typeRowShape ret_t -> toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) | otherwise -> @@ -1474,21 +1478,21 @@ initialCtx = _ -> error $ "Invalid arguments to map intrinsic:\n" - ++ unlines [prettyString t, show v] + ++ unlines [prettyString t, show f, show xs] where typeRowShape = sequenceA . structTypeShape mempty . stripArray 1 def s | "reduce" `isPrefixOf` s = Just $ - fun3t $ \f ne xs -> + fun3 $ \f ne xs -> foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs def "scan" = Just $ - fun3t $ \f ne xs -> do + fun3 $ \f ne xs -> do let next (out, acc) x = do x' <- apply2 noLoc mempty f acc x pure (x' : out, x') toArray' (valueShape ne) . reverse . fst <$> foldM next ([], ne) (snd $ fromArray xs) def "scatter" = Just $ - fun3t $ \arr is vs -> + fun3 $ \arr is vs -> case arr of ValueArray shape arr' -> pure $ @@ -1503,7 +1507,7 @@ initialCtx = then arr' // [(i, v)] else arr' def "scatter_2d" = Just $ - fun3t $ \arr is vs -> + fun3 $ \arr is vs -> case arr of ValueArray _ _ -> pure $ @@ -1518,7 +1522,7 @@ initialCtx = update _ _ = error "scatter_2d expects 2-dimensional indices" def "scatter_3d" = Just $ - fun3t $ \arr is vs -> + fun3 $ \arr is vs -> case arr of ValueArray _ _ -> pure $ @@ -1532,7 +1536,7 @@ initialCtx = fromMaybe arr $ writeArray (map (IndexingFix . asInt64) idxs) arr v update _ _ = error "scatter_3d expects 3-dimensional indices" - def "hist_1d" = Just . fun6t $ \_ arr fun _ is vs -> + def "hist_1d" = Just . fun6 $ \_ arr fun _ is vs -> foldM (update fun) arr @@ -1541,7 +1545,7 @@ initialCtx = op = apply2 mempty mempty update fun arr (i, v) = fromMaybe arr <$> updateArray (op fun) [IndexingFix i] arr v - def "hist_2d" = Just . fun6t $ \_ arr fun _ is vs -> + def "hist_2d" = Just . fun6 $ \_ arr fun _ is vs -> foldM (update fun) arr @@ -1553,7 +1557,7 @@ initialCtx = <$> updateArray (op fun) (map (IndexingFix . asInt64) idxs) arr v update _ _ _ = error "hist_2d: bad index value" - def "hist_3d" = Just . fun6t $ \_ arr fun _ is vs -> + def "hist_3d" = Just . fun6 $ \_ arr fun _ is vs -> foldM (update fun) arr @@ -1566,7 +1570,7 @@ initialCtx = update _ _ _ = error "hist_2d: bad index value" def "partition" = Just $ - fun3t $ \k f xs -> do + fun3 $ \k f xs -> do let (ShapeDim _ rowshape, xs') = fromArray xs next outs x = do @@ -1586,7 +1590,7 @@ initialCtx = insertAt i x (l : ls) = l : insertAt (i - 1) x ls insertAt _ _ ls = ls def "scatter_stream" = Just $ - fun3t $ \dest f vs -> + fun3 $ \dest f vs -> case (dest, vs) of ( ValueArray dest_shape dest_arr, ValueArray _ vs_arr @@ -1601,7 +1605,7 @@ initialCtx = _ -> error $ "scatter_stream expects array, but got: " <> prettyString (show vs, show vs) def "hist_stream" = Just $ - fun5t $ \dest op _ne f vs -> + fun5 $ \dest op _ne f vs -> case (dest, vs) of ( ValueArray dest_shape dest_arr, ValueArray _ vs_arr @@ -1616,7 +1620,7 @@ initialCtx = _ -> error $ "hist_stream expects array, but got: " <> prettyString (show dest, show vs) def "acc_write" = Just $ - fun3t $ \acc i v -> + fun3 $ \acc i v -> case (acc, i) of ( ValueAcc op acc_arr, ValuePrim (SignedValue (Int64Value i')) @@ -1630,7 +1634,7 @@ initialCtx = _ -> error $ "acc_write invalid arguments: " <> prettyString (show acc, show i, show v) -- - def "flat_index_2d" = Just . fun6t $ \arr offset n1 s1 n2 s2 -> do + def "flat_index_2d" = Just . fun6 $ \arr offset n1 s1 n2 s2 -> do let offset' = asInt64 offset n1' = asInt64 n1 n2' = asInt64 n2 @@ -1649,7 +1653,7 @@ initialCtx = bad mempty mempty $ "Index out of bounds: " <> prettyText [((n1', s1'), (n2', s2'))] -- - def "flat_update_2d" = Just . fun5t $ \arr offset s1 s2 v -> do + def "flat_update_2d" = Just . fun5 $ \arr offset s1 s2 v -> do let offset' = asInt64 offset s1' = asInt64 s1 s2' = asInt64 s2 @@ -1666,7 +1670,7 @@ initialCtx = "Index out of bounds: " <> prettyText [((n1, s1'), (n2, s2'))] s -> error $ "flat_update_2d: invalid arg shape: " ++ show s -- - def "flat_index_3d" = Just . fun8t $ \arr offset n1 s1 n2 s2 n3 s3 -> do + def "flat_index_3d" = Just . fun8 $ \arr offset n1 s1 n2 s2 n3 s3 -> do let offset' = asInt64 offset n1' = asInt64 n1 n2' = asInt64 n2 @@ -1688,7 +1692,7 @@ initialCtx = bad mempty mempty $ "Index out of bounds: " <> prettyText [((n1', s1'), (n2', s2'), (n3', s3'))] -- - def "flat_update_3d" = Just . fun6t $ \arr offset s1 s2 s3 v -> do + def "flat_update_3d" = Just . fun6 $ \arr offset s1 s2 s3 v -> do let offset' = asInt64 offset s1' = asInt64 s1 s2' = asInt64 s2 @@ -1706,7 +1710,7 @@ initialCtx = "Index out of bounds: " <> prettyText [((n1, s1'), (n2, s2'), (n3, s3'))] s -> error $ "flat_update_3d: invalid arg shape: " ++ show s -- - def "flat_index_4d" = Just . fun10t $ \arr offset n1 s1 n2 s2 n3 s3 n4 s4 -> do + def "flat_index_4d" = Just . fun10 $ \arr offset n1 s1 n2 s2 n3 s3 n4 s4 -> do let offset' = asInt64 offset n1' = asInt64 n1 n2' = asInt64 n2 @@ -1731,7 +1735,7 @@ initialCtx = bad mempty mempty $ "Index out of bounds: " <> prettyText [(((n1', s1'), (n2', s2')), ((n3', s3'), (n4', s4')))] -- - def "flat_update_4d" = Just . fun7t $ \arr offset s1 s2 s3 s4 v -> do + def "flat_update_4d" = Just . fun7 $ \arr offset s1 s2 s3 s4 v -> do let offset' = asInt64 offset s1' = asInt64 s1 s2' = asInt64 s2 @@ -1762,7 +1766,7 @@ initialCtx = fromPair (Just [x, y]) = (x, y) fromPair _ = error "Not a pair" def "zip" = Just $ - fun2t $ \xs ys -> do + fun2 $ \xs ys -> do let ShapeDim _ xs_rowshape = valueShape xs ShapeDim _ ys_rowshape = valueShape ys pure $ @@ -1770,7 +1774,7 @@ initialCtx = map toTuple $ transpose [snd $ fromArray xs, snd $ fromArray ys] def "concat" = Just $ - fun2t $ \xs ys -> do + fun2 $ \xs ys -> do let (ShapeDim _ rowshape, xs') = fromArray xs (_, ys') = fromArray ys pure $ toArray' rowshape $ xs' ++ ys' @@ -1784,7 +1788,7 @@ initialCtx = genericTake m $ transpose (map (snd . fromArray) xs') ++ repeat [] def "rotate" = Just $ - fun2t $ \i xs -> do + fun2 $ \i xs -> do let (shape, xs') = fromArray xs pure $ let idx = if null xs' then 0 else rem (asInt i) (length xs') @@ -1800,7 +1804,7 @@ initialCtx = let (ShapeDim n (ShapeDim m shape), xs') = fromArray xs pure $ toArray (ShapeDim (n * m) shape) $ concatMap (snd . fromArray) xs' def "unflatten" = Just $ - fun3t $ \n m xs -> do + fun3 $ \n m xs -> do let (ShapeDim xs_size innershape, xs') = fromArray xs rowshape = ShapeDim (asInt64 m) innershape shape = ShapeDim (asInt64 n) rowshape @@ -1816,10 +1820,10 @@ initialCtx = <> "]" else pure $ toArray shape $ map (toArray rowshape) $ chunk (asInt m) xs' def "vjp2" = Just $ - fun3t $ + fun3 $ \_ _ _ -> bad noLoc mempty "Interpreter does not support autodiff." def "jvp2" = Just $ - fun3t $ + fun3 $ \_ _ _ -> bad noLoc mempty "Interpreter does not support autodiff." def "acc" = Nothing def s | nameFromString s `M.member` namesToPrimTypes = Nothing @@ -1878,7 +1882,7 @@ valueType v = checkEntryArgs :: VName -> [V.Value] -> StructType -> Either T.Text () checkEntryArgs entry args entry_t - | args_ts == param_ts = + | args_ts == map snd param_ts = pure () | otherwise = Left . docText $ @@ -1916,9 +1920,9 @@ interpretFunction ctx fname vs = do f <- evalTermVar (ctxEnv ctx) (qualName fname) ft foldM (apply noLoc mempty) f vs' where - updateType (vt : vts) (Scalar (Arrow als u pt (RetType dims rt))) = do + updateType (vt : vts) (Scalar (Arrow als pn d pt (RetType dims rt))) = do checkInput vt pt - Scalar . Arrow als u (valueStructType vt) . RetType dims <$> updateType vts rt + Scalar . Arrow als pn d (valueStructType vt) . RetType dims <$> updateType vts rt updateType _ t = Right t diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 99be6a3840..210afc3fe6 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -115,6 +115,10 @@ prettyRetType _ (RetType dims t) = instance Pretty (Shape dim) => Pretty (RetTypeBase dim as) where pretty = prettyRetType 0 +instance Pretty Diet where + pretty Consume = "*" + pretty Observe = "" + prettyScalarType :: Pretty (Shape dim) => Int -> ScalarTypeBase dim as -> Doc a prettyScalarType _ (Prim et) = pretty et prettyScalarType p (TypeVar _ u v targs) = @@ -128,11 +132,16 @@ prettyScalarType _ (Record fs) where ppField (name, t) = pretty (nameToString name) <> colon <+> align (pretty t) fs' = map ppField $ M.toList fs -prettyScalarType p (Arrow _ (Named v) t1 t2) = +prettyScalarType p (Arrow _ (Named v) d t1 t2) = + parensIf (p > 1) $ + parens (prettyName v <> colon <+> pretty d <> align (pretty t1)) + <+> "->" + <+> prettyRetType 1 t2 +prettyScalarType p (Arrow _ Unnamed d t1 t2) = parensIf (p > 1) $ - parens (prettyName v <> colon <+> align (pretty t1)) <+> "->" <+> prettyRetType 1 t2 -prettyScalarType p (Arrow _ Unnamed t1 t2) = - parensIf (p > 1) $ prettyType 2 t1 <+> "->" <+> prettyRetType 1 t2 + (pretty d <> prettyType 2 t1) + <+> "->" + <+> prettyRetType 1 t2 prettyScalarType p (Sum cs) = parensIf (p > 0) $ group (align (mconcat $ punctuate (" |" <> line) cs')) @@ -353,7 +362,7 @@ prettyExp p (Lambda params body rettype _ _) = parensIf (p /= -1) $ "\\" <> hsep (map pretty params) <> ppAscription rettype <+> "->" - indent 2 (pretty body) + indent 2 (align (pretty body)) prettyExp _ (OpSection binop _ _) = parens $ pretty binop prettyExp _ (OpSectionLeft binop _ x _ _ _) = diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 73a1c26ec4..de6d7be73b 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -48,6 +48,7 @@ module Language.Futhark.Prop orderZero, unfoldFunType, foldFunType, + foldFunTypeFromParams, typeVars, -- * Operations on types @@ -174,8 +175,8 @@ traverseDims f = go mempty PosImmediate Scalar . Sum <$> traverse (traverse (go bound b)) cs go _ _ (Scalar (Prim t)) = pure $ Scalar $ Prim t - go bound _ (Scalar (Arrow als p t1 (RetType dims t2))) = - Scalar <$> (Arrow als p <$> go bound' PosParam t1 <*> (RetType dims <$> go bound' PosReturn t2)) + go bound _ (Scalar (Arrow als p u t1 (RetType dims t2))) = + Scalar <$> (Arrow als p u <$> go bound' PosParam t1 <*> (RetType dims <$> go bound' PosReturn t2)) where bound' = S.fromList dims @@ -208,16 +209,16 @@ aliases :: Monoid as => TypeBase shape as -> as aliases = bifoldMap (const mempty) id -- | @diet t@ returns a description of how a function parameter of --- type @t@ might consume its argument. +-- type @t@ consumes its argument. diet :: TypeBase shape as -> Diet -diet (Scalar (Record ets)) = RecordDiet $ fmap diet ets +diet (Scalar (Record ets)) = foldl max Observe $ fmap diet ets diet (Scalar (Prim _)) = Observe -diet (Scalar (Arrow _ _ t1 (RetType _ t2))) = FuncDiet (diet t1) (diet t2) +diet (Scalar (Arrow {})) = Observe diet (Array _ Unique _ _) = Consume diet (Array _ Nonunique _ _) = Observe diet (Scalar (TypeVar _ Unique _ _)) = Consume diet (Scalar (TypeVar _ Nonunique _ _)) = Observe -diet (Scalar (Sum cs)) = SumDiet $ M.map (map diet) cs +diet (Scalar (Sum cs)) = foldl max Observe $ foldMap (map diet) cs -- | Convert any type to one that has rank information, no alias -- information, and no embedded names. @@ -334,8 +335,14 @@ combineTypeShapes (Scalar (Sum cs1)) (Scalar (Sum cs2)) M.map (uncurry $ zipWith combineTypeShapes) (M.intersectionWith (,) cs1 cs2) -combineTypeShapes (Scalar (Arrow als1 p1 a1 (RetType dims1 b1))) (Scalar (Arrow als2 _p2 a2 (RetType _ b2))) = - Scalar $ Arrow (als1 <> als2) p1 (combineTypeShapes a1 a2) (RetType dims1 (combineTypeShapes b1 b2)) +combineTypeShapes (Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1))) (Scalar (Arrow als2 _p2 _d2 a2 (RetType _ b2))) = + Scalar $ + Arrow + (als1 <> als2) + p1 + d1 + (combineTypeShapes a1 a2) + (RetType dims1 (combineTypeShapes b1 b2)) combineTypeShapes (Scalar (TypeVar als1 u1 v targs1)) (Scalar (TypeVar als2 _ _ targs2)) = Scalar $ TypeVar (als1 <> als2) u1 v $ zipWith f targs1 targs2 where @@ -385,12 +392,12 @@ matchDims onDims = matchDims' mempty <$> traverse (traverse (uncurry (matchDims' bound))) (M.intersectionWith zip cs1 cs2) - ( Scalar (Arrow als1 p1 a1 (RetType dims1 b1)), - Scalar (Arrow als2 p2 a2 (RetType dims2 b2)) + ( Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1)), + Scalar (Arrow als2 p2 _d2 a2 (RetType dims2 b2)) ) -> let bound' = mapMaybe paramName [p1, p2] <> dims1 <> dims2 <> bound in Scalar - <$> ( Arrow (als1 <> als2) p1 + <$> ( Arrow (als1 <> als2) p1 d1 <$> matchDims' bound' a1 a2 <*> (RetType dims1 <$> matchDims' bound' b1 b2) ) @@ -488,45 +495,62 @@ typeOf (Update e _ _ _) = typeOf e `setAliases` mempty typeOf (RecordUpdate _ _ _ (Info t) _) = t typeOf (Assert _ e _ _) = typeOf e typeOf (Lambda params _ _ (Info (als, t)) _) = - let RetType [] t' = foldr (arrow . patternParam) t params - in t' `setAliases` als - where - arrow (Named v, x) (RetType dims y) = - RetType [] $ Scalar $ Arrow () (Named v) x $ RetType (v : dims) y - arrow (pn, tx) y = - RetType [] $ Scalar $ Arrow () pn tx y + funType params t `setAliases` als typeOf (OpSection _ (Info t) _) = t typeOf (OpSectionLeft _ _ _ (_, Info (pn, pt2)) (Info ret, _) _) = - Scalar $ Arrow mempty pn pt2 ret + Scalar $ Arrow mempty pn Observe pt2 ret typeOf (OpSectionRight _ _ _ (Info (pn, pt1), _) (Info ret) _) = - Scalar $ Arrow mempty pn pt1 ret + Scalar $ Arrow mempty pn Observe pt1 ret typeOf (ProjectSection _ (Info t) _) = t typeOf (IndexSection _ (Info t) _) = t typeOf (Constr _ _ (Info t) _) = t typeOf (Attr _ e _) = typeOf e typeOf (AppExp _ (Info res)) = appResType res +-- | The type of a function with the given parameters and return type. +funType :: [PatBase Info VName] -> StructRetType -> StructType +funType params ret = + let RetType _ t = foldr (arrow . patternParam) ret params + in t + where + arrow (xp, d, xt) yt = + RetType [] $ Scalar $ Arrow () xp d xt' yt + where + xt' = xt `setUniqueness` Nonunique + -- | @foldFunType ts ret@ creates a function type ('Arrow') that takes -- @ts@ as parameters and returns @ret@. foldFunType :: Monoid as => - [TypeBase dim pas] -> + [(Diet, TypeBase dim pas)] -> RetTypeBase dim as -> TypeBase dim as foldFunType ps ret = let RetType _ t = foldr arrow ret ps in t where - arrow t1 t2 = - RetType [] $ Scalar $ Arrow mempty Unnamed (toStruct t1) t2 + arrow (d, t1) t2 = + RetType [] $ Scalar $ Arrow mempty Unnamed d t1' t2 + where + t1' = toStruct t1 `setUniqueness` Nonunique + +foldFunTypeFromParams :: + Monoid as => + [PatBase Info VName] -> + RetTypeBase Size as -> + TypeBase Size as +foldFunTypeFromParams params = + foldFunType (zip (map diet params_ts) params_ts) + where + params_ts = map patternStructType params -- | Extract the parameter types and return type from a type. -- If the type is not an arrow type, the list of parameter types is empty. -unfoldFunType :: TypeBase dim as -> ([TypeBase dim ()], TypeBase dim ()) -unfoldFunType (Scalar (Arrow _ _ t1 (RetType _ t2))) = +unfoldFunType :: TypeBase dim as -> ([(Diet, TypeBase dim ())], TypeBase dim ()) +unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) = let (ps, r) = unfoldFunType t2 - in (t1 : ps, r) + in ((d, t1) : ps, r) unfoldFunType t = ([], toStruct t) -- | The type scheme of a value binding, comprising the type @@ -547,14 +571,6 @@ valBindBound vb = [] -> retDims (unInfo (valBindRetType vb)) _ -> [] --- | The type of a function with the given parameters and return type. -funType :: [PatBase Info VName] -> StructRetType -> StructType -funType params ret = - let RetType _ t = foldr (arrow . patternParam) ret params - in t - where - arrow (xp, xt) yt = RetType [] $ Scalar $ Arrow () xp xt yt - -- | The type names mentioned in a type. typeVars :: Monoid as => TypeBase dim as -> S.Set VName typeVars t = @@ -562,7 +578,7 @@ typeVars t = Scalar Prim {} -> mempty Scalar (TypeVar _ _ tn targs) -> mconcat $ S.singleton (qualLeaf tn) : map typeArgFree targs - Scalar (Arrow _ _ t1 (RetType _ t2)) -> typeVars t1 <> typeVars t2 + Scalar (Arrow _ _ _ t1 (RetType _ t2)) -> typeVars t1 <> typeVars t2 Scalar (Record fields) -> foldMap typeVars fields Scalar (Sum cs) -> mconcat $ (foldMap . fmap) typeVars cs Array _ _ _ rt -> typeVars $ Scalar rt @@ -644,17 +660,19 @@ patternStructType = toStruct . patternType -- | When viewed as a function parameter, does this pattern correspond -- to a named parameter of some type? -patternParam :: PatBase Info VName -> (PName, StructType) +patternParam :: PatBase Info VName -> (PName, Diet, StructType) patternParam (PatParens p _) = patternParam p patternParam (PatAttr _ p _) = patternParam p patternParam (PatAscription (Id v (Info t) _) _ _) = - (Named v, toStruct t) + (Named v, diet t, toStruct t) patternParam (Id v (Info t) _) = - (Named v, toStruct t) + (Named v, diet t, toStruct t) patternParam p = - (Unnamed, patternStructType p) + (Unnamed, diet p_t, p_t) + where + p_t = patternStructType p -- | Names of primitive types to types. This is only valid if no -- shadowing is going on, but useful for tools. @@ -677,7 +695,7 @@ namesToPrimTypes = data Intrinsic = IntrinsicMonoFun [PrimType] PrimType | IntrinsicOverloadedFun [PrimType] [Maybe PrimType] (Maybe PrimType) - | IntrinsicPolyFun [TypeParamBase VName] [StructType] (RetTypeBase Size ()) + | IntrinsicPolyFun [TypeParamBase VName] [(Diet, StructType)] (RetTypeBase Size ()) | IntrinsicType Liftedness [TypeParamBase VName] StructType | IntrinsicEquality -- Special cased. @@ -732,16 +750,16 @@ intrinsics = ++ [ ( "flatten", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [Array () Nonunique (shape [n, m]) t_a] + [(Observe, Array () Nonunique (shape [n, m]) t_a)] $ RetType [k] $ Array () Nonunique (shape [k]) t_a ), ( "unflatten", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar $ Prim $ Signed Int64, - Scalar $ Prim $ Signed Int64, - Array () Nonunique (shape [n]) t_a + [ (Observe, Scalar $ Prim $ Signed Int64), + (Observe, Scalar $ Prim $ Signed Int64), + (Observe, Array () Nonunique (shape [n]) t_a) ] $ RetType [k, m] $ Array () Nonunique (shape [k, m]) t_a @@ -749,33 +767,37 @@ intrinsics = ( "concat", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [arr_a $ shape [n], arr_a $ shape [m]] + [ (Observe, array_a $ shape [n]), + (Observe, array_a $ shape [m]) + ] $ RetType [k] - $ uarr_a + $ uarray_a $ shape [k] ), ( "rotate", IntrinsicPolyFun [tp_a, sp_n] - [Scalar $ Prim $ Signed Int64, arr_a $ shape [n]] + [ (Observe, Scalar $ Prim $ Signed Int64), + (Observe, array_a $ shape [n]) + ] $ RetType [] - $ arr_a + $ array_a $ shape [n] ), ( "transpose", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [arr_a $ shape [n, m]] + [(Observe, array_a $ shape [n, m])] $ RetType [] - $ arr_a + $ array_a $ shape [m, n] ), ( "scatter", IntrinsicPolyFun [tp_a, sp_n, sp_l] - [ Array () Unique (shape [n]) t_a, - Array () Nonunique (shape [l]) (Prim $ Signed Int64), - Array () Nonunique (shape [l]) t_a + [ (Consume, Array () Unique (shape [n]) t_a), + (Observe, Array () Nonunique (shape [l]) (Prim $ Signed Int64)), + (Observe, Array () Nonunique (shape [l]) t_a) ] $ RetType [] $ Array () Unique (shape [n]) t_a @@ -783,98 +805,100 @@ intrinsics = ( "scatter_2d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_l] - [ uarr_a $ shape [n, m], - Array () Nonunique (shape [l]) (tupInt64 2), - Array () Nonunique (shape [l]) t_a + [ (Consume, uarray_a $ shape [n, m]), + (Observe, Array () Nonunique (shape [l]) (tupInt64 2)), + (Observe, Array () Nonunique (shape [l]) t_a) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n, m] ), ( "scatter_3d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_k, sp_l] - [ uarr_a $ shape [n, m, k], - Array () Nonunique (shape [l]) (tupInt64 3), - Array () Nonunique (shape [l]) t_a + [ (Consume, uarray_a $ shape [n, m, k]), + (Observe, Array () Nonunique (shape [l]) (tupInt64 3)), + (Observe, Array () Nonunique (shape [l]) t_a) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n, m, k] ), ( "zip", IntrinsicPolyFun [tp_a, tp_b, sp_n] - [arr_a (shape [n]), arr_b (shape [n])] + [ (Observe, array_a (shape [n])), + (Observe, array_b (shape [n])) + ] $ RetType [] - $ tuple_uarr (Scalar t_a) (Scalar t_b) + $ tuple_uarray (Scalar t_a) (Scalar t_b) $ shape [n] ), ( "unzip", IntrinsicPolyFun [tp_a, tp_b, sp_n] - [tuple_arr (Scalar t_a) (Scalar t_b) $ shape [n]] + [(Observe, tuple_arr (Scalar t_a) (Scalar t_b) $ shape [n])] $ RetType [] . Scalar . Record . M.fromList - $ zip tupleFieldNames [arr_a $ shape [n], arr_b $ shape [n]] + $ zip tupleFieldNames [array_a $ shape [n], array_b $ shape [n]] ), ( "hist_1d", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [ Scalar $ Prim $ Signed Int64, - uarr_a $ shape [m], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Array () Nonunique (shape [n]) (tupInt64 1), - arr_a (shape [n]) + [ (Consume, Scalar $ Prim $ Signed Int64), + (Observe, uarray_a $ shape [m]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 1)), + (Observe, array_a (shape [n])) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [m] ), ( "hist_2d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_k] - [ Scalar $ Prim $ Signed Int64, - uarr_a $ shape [m, k], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Array () Nonunique (shape [n]) (tupInt64 2), - arr_a (shape [n]) + [ (Observe, Scalar $ Prim $ Signed Int64), + (Consume, uarray_a $ shape [m, k]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 2)), + (Observe, array_a (shape [n])) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [m, k] ), ( "hist_3d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_k, sp_l] - [ Scalar $ Prim $ Signed Int64, - uarr_a $ shape [m, k, l], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Array () Nonunique (shape [n]) (tupInt64 3), - arr_a (shape [n]) + [ (Observe, Scalar $ Prim $ Signed Int64), + (Consume, uarray_a $ shape [m, k, l]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 3)), + (Observe, array_a (shape [n])) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [m, k, l] ), ( "map", IntrinsicPolyFun [tp_a, tp_b, sp_n] - [ Scalar t_a `arr` Scalar t_b, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, array_a $ shape [n]) ] $ RetType [] - $ uarr_b + $ uarray_b $ shape [n] ), ( "reduce", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) ] $ RetType [] $ Scalar t_a @@ -882,9 +906,9 @@ intrinsics = ( "reduce_comm", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) ] $ RetType [] $ Scalar t_a @@ -892,24 +916,24 @@ intrinsics = ( "scan", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ), ( "partition", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar (Prim $ Signed Int32), - Scalar t_a `arr` Scalar (Prim $ Signed Int64), - arr_a $ shape [n] + [ (Observe, Scalar (Prim $ Signed Int32)), + (Observe, Scalar t_a `arr` Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [n]) ] ( RetType [m] . Scalar $ tupleRecord - [ uarr_a $ shape [k], + [ uarray_a $ shape [k], Array () Unique (shape [n]) (Prim $ Signed Int64) ] ) @@ -917,48 +941,52 @@ intrinsics = ( "acc_write", IntrinsicPolyFun [sp_k, tp_a] - [ Scalar $ accType arr_ka, - Scalar (Prim $ Signed Int64), - Scalar t_a + [ (Consume, Scalar $ accType array_ka), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar t_a) ] $ RetType [] $ Scalar - $ accType arr_ka + $ accType array_ka ), ( "scatter_stream", IntrinsicPolyFun [tp_a, tp_b, sp_k, sp_n] - [ uarr_ka, - Scalar (accType arr_ka) - `arr` ( Scalar t_b - `arr` Scalar (accType $ arr_a $ shape [k]) - ), - arr_b $ shape [n] + [ (Consume, uarray_ka), + ( Observe, + Scalar (accType array_ka) + `carr` ( Scalar t_b + `arr` Scalar (accType $ array_a $ shape [k]) + ) + ), + (Observe, array_b $ shape [n]) ] - $ RetType [] uarr_ka + $ RetType [] uarray_ka ), ( "hist_stream", IntrinsicPolyFun [tp_a, tp_b, sp_k, sp_n] - [ uarr_a $ shape [k], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Scalar (accType arr_ka) - `arr` ( Scalar t_b - `arr` Scalar (accType $ arr_a $ shape [k]) - ), - arr_b $ shape [n] + [ (Consume, uarray_a $ shape [k]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + ( Observe, + Scalar (accType array_ka) + `carr` ( Scalar t_b + `arr` Scalar (accType $ array_a $ shape [k]) + ) + ), + (Observe, array_b $ shape [n]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [k] ), ( "jvp2", IntrinsicPolyFun [tp_a, tp_b] - [ Scalar t_a `arr` Scalar t_b, - Scalar t_a, - Scalar t_a + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, Scalar t_a), + (Observe, Scalar t_a) ] $ RetType [] $ Scalar @@ -967,9 +995,9 @@ intrinsics = ( "vjp2", IntrinsicPolyFun [tp_a, tp_b] - [ Scalar t_a `arr` Scalar t_b, - Scalar t_a, - Scalar t_b + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, Scalar t_a), + (Observe, Scalar t_b) ] $ RetType [] $ Scalar @@ -981,91 +1009,91 @@ intrinsics = [ ( "flat_index_2d", IntrinsicPolyFun [tp_a, sp_n] - [ arr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64) + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) ] $ RetType [m, k] - $ arr_a + $ array_a $ shape [m, k] ), ( "flat_update_2d", IntrinsicPolyFun [tp_a, sp_n, sp_k, sp_l] - [ uarr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - arr_a $ shape [k, l] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ), ( "flat_index_3d", IntrinsicPolyFun [tp_a, sp_n] - [ arr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64) + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) ] $ RetType [m, k, l] - $ arr_a + $ array_a $ shape [m, k, l] ), ( "flat_update_3d", IntrinsicPolyFun [tp_a, sp_n, sp_k, sp_l, sp_p] - [ uarr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - arr_a $ shape [k, l, p] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l, p]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ), ( "flat_index_4d", IntrinsicPolyFun [tp_a, sp_n] - [ arr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64) + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) ] $ RetType [m, k, l, p] - $ arr_a + $ array_a $ shape [m, k, l, p] ), ( "flat_update_4d", IntrinsicPolyFun [tp_a, sp_n, sp_k, sp_l, sp_p, sp_q] - [ uarr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - arr_a $ shape [k, l, p, q] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l, p, q]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ) ] @@ -1073,13 +1101,13 @@ intrinsics = [a, b, n, m, k, l, p, q] = zipWith VName (map nameFromString ["a", "b", "n", "m", "k", "l", "p", "q"]) [0 ..] t_a = TypeVar () Nonunique (qualName a) [] - arr_a s = Array () Nonunique s t_a - uarr_a s = Array () Unique s t_a + array_a s = Array () Nonunique s t_a + uarray_a s = Array () Unique s t_a tp_a = TypeParamType Unlifted a mempty t_b = TypeVar () Nonunique (qualName b) [] - arr_b s = Array () Nonunique s t_b - uarr_b s = Array () Unique s t_b + array_b s = Array () Nonunique s t_b + uarray_b s = Array () Unique s t_b tp_b = TypeParamType Unlifted b mempty [sp_n, sp_m, sp_k, sp_l, sp_p, sp_q] = map (`TypeParamDim` mempty) [n, m, k, l, p, q] @@ -1092,12 +1120,13 @@ intrinsics = Nonunique s (Record (M.fromList $ zip tupleFieldNames [x, y])) - tuple_uarr x y s = tuple_arr x y s `setUniqueness` Unique + tuple_uarray x y s = tuple_arr x y s `setUniqueness` Unique - arr x y = Scalar $ Arrow mempty Unnamed x (RetType [] y) + arr x y = Scalar $ Arrow mempty Unnamed Observe x (RetType [] y) + carr x y = Scalar $ Arrow mempty Unnamed Consume x (RetType [] y) - arr_ka = Array () Nonunique (Shape [NamedSize $ qualName k]) t_a - uarr_ka = Array () Unique (Shape [NamedSize $ qualName k]) t_a + array_ka = Array () Nonunique (Shape [NamedSize $ qualName k]) t_a + uarray_ka = Array () Unique (Shape [NamedSize $ qualName k]) t_a accType t = TypeVar () Unique (qualName (fst intrinsicAcc)) [TypeArgType t mempty] diff --git a/src/Language/Futhark/Query.hs b/src/Language/Futhark/Query.hs index d703b758ad..06e0b02c8f 100644 --- a/src/Language/Futhark/Query.hs +++ b/src/Language/Futhark/Query.hs @@ -88,7 +88,7 @@ expDefs e = Lambda params _ _ _ _ -> mconcat (map patternDefs params) AppExp (LetFun name (tparams, params, _, Info ret, _) _ loc) _ -> - let name_t = foldFunType (map patternStructType params) ret + let name_t = foldFunType (map (undefined . patternStructType) params) ret in M.singleton name (DefBound $ BoundTerm name_t (locOf loc)) <> mconcat (map typeParamDefs tparams) <> mconcat (map patternDefs params) @@ -111,7 +111,7 @@ valBindDefs vbind = <> expDefs (valBindBody vbind) where vbind_t = - foldFunType (map patternStructType (valBindParams vbind)) $ + foldFunType (map (undefined . patternStructType) (valBindParams vbind)) $ unInfo $ valBindRetType vbind diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index fd060ac4b8..6ab3486366 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -300,7 +300,7 @@ data ScalarTypeBase dim as | Sum (M.Map Name [TypeBase dim as]) | -- | The aliasing corresponds to the lexical -- closure of the function. - Arrow as PName (TypeBase dim ()) (RetTypeBase dim as) + Arrow as PName Diet (TypeBase dim ()) (RetTypeBase dim as) deriving (Eq, Ord, Show) instance Bitraversable ScalarTypeBase where @@ -308,8 +308,8 @@ instance Bitraversable ScalarTypeBase where bitraverse f g (Record fs) = Record <$> traverse (bitraverse f g) fs bitraverse f g (TypeVar als u t args) = TypeVar <$> g als <*> pure u <*> pure t <*> traverse (traverse f) args - bitraverse f g (Arrow als v t1 t2) = - Arrow <$> g als <*> pure v <*> bitraverse f pure t1 <*> bitraverse f g t2 + bitraverse f g (Arrow als v d t1 t2) = + Arrow <$> g als <*> pure v <*> pure d <*> bitraverse f pure t1 <*> bitraverse f g t2 bitraverse f g (Sum cs) = Sum <$> (traverse . traverse) (bitraverse f g) cs instance Bifunctor ScalarTypeBase where @@ -456,21 +456,13 @@ instance Located (TypeArgExp vn) where locOf (TypeArgExpDim _ loc) = locOf loc locOf (TypeArgExpType t) = locOf t --- | Information about which parts of a value/type are consumed. +-- | Information about which parts of a parameter are consumed. This +-- can be considered kind of an effect on the function. data Diet - = -- | Consumes these fields in the record. - RecordDiet (M.Map Name Diet) - | -- | Consume these parts of the constructors. - SumDiet (M.Map Name [Diet]) - | -- | A function that consumes its argument(s) like this. - -- The final 'Diet' should always be 'Observe', as there - -- is no way for a function to consume its return value. - FuncDiet Diet Diet - | -- | Consumes this value. - Consume - | -- | Only observes value in this position, does - -- not consume. + = -- | Does not consume the parameter. Observe + | -- | Consumes the parameter. + Consume deriving (Eq, Ord, Show) -- | An identifier consists of its name and the type of the value diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index d113f2f571..1ee8f04032 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -301,10 +301,11 @@ traverseScalarType _ _ _ (Prim t) = pure $ Prim t traverseScalarType f g h (Record fs) = Record <$> traverse (traverseType f g h) fs traverseScalarType f g h (TypeVar als u t args) = TypeVar <$> h als <*> pure u <*> f t <*> traverse (traverseTypeArg f g) args -traverseScalarType f g h (Arrow als v t1 (RetType dims t2)) = +traverseScalarType f g h (Arrow als v u t1 (RetType dims t2)) = Arrow <$> h als <*> pure v + <*> pure u <*> traverseType f g pure t1 <*> (RetType dims <$> traverseType f g h t2) traverseScalarType f g h (Sum cs) = diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 7874be55f6..473281bb83 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -617,10 +617,10 @@ entryPoint params orig_ret_te (RetType ret orig_ret) = pname (Named v) = baseName v pname Unnamed = "_" - onRetType (Just (TEArrow p t1_te t2_te _)) (Scalar (Arrow _ _ t1 (RetType _ t2))) = + onRetType (Just (TEArrow p t1_te t2_te _)) (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = let (xs, y) = onRetType (Just t2_te) t2 in (EntryParam (maybe "_" baseName p) (EntryType t1 (Just t1_te)) : xs, y) - onRetType _ (Scalar (Arrow _ p t1 (RetType _ t2))) = + onRetType _ (Scalar (Arrow _ p _ t1 (RetType _ t2))) = let (xs, y) = onRetType Nothing t2 in (EntryParam (pname p) (EntryType t1 Nothing) : xs, y) onRetType te t = @@ -639,8 +639,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype withIndexLink "polymorphic-entry" "Entry point functions may not be polymorphic." - | not (all patternOrderZero params) - || not (all orderZero rettype_params) + | not (all orderZero param_ts) || not (orderZero rettype') = typeError loc mempty $ withIndexLink @@ -649,7 +648,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype | sizes_only_in_ret <- S.fromList (map typeParamName tparams) `S.intersection` freeInType rettype' - `S.difference` foldMap freeInType (map patternStructType params ++ rettype_params), + `S.difference` foldMap freeInType param_ts, not $ S.null sizes_only_in_ret = typeError loc mempty $ withIndexLink @@ -670,6 +669,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype where (RetType _ rettype_t) = rettype (rettype_params, rettype') = unfoldFunType rettype_t + param_ts = map patternStructType params ++ map snd rettype_params checkValBind :: ValBindBase NoInfo Name -> TypeM (Env, ValBind) checkValBind (ValBind entry fname maybe_tdecl NoInfo tparams params body doc attrs loc) = do @@ -705,9 +705,9 @@ nastyType t@Array {} = nastyType $ stripArray 1 t nastyType _ = True nastyReturnType :: Monoid als => Maybe (TypeExp VName) -> TypeBase dim als -> Bool -nastyReturnType Nothing (Scalar (Arrow _ _ t1 (RetType _ t2))) = +nastyReturnType Nothing (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = nastyType t1 || nastyReturnType Nothing t2 -nastyReturnType (Just (TEArrow _ te1 te2 _)) (Scalar (Arrow _ _ t1 (RetType _ t2))) = +nastyReturnType (Just (TEArrow _ te1 te2 _)) (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = (not (niceTypeExp te1) && nastyType t1) || nastyReturnType (Just te2) t2 nastyReturnType (Just te) _ diff --git a/src/Language/Futhark/TypeChecker/Modules.hs b/src/Language/Futhark/TypeChecker/Modules.hs index ea7cdc1065..b9a1d133af 100644 --- a/src/Language/Futhark/TypeChecker/Modules.hs +++ b/src/Language/Futhark/TypeChecker/Modules.hs @@ -149,8 +149,8 @@ newNamesForMTy orig_mty = do Scalar $ Sum $ (fmap . fmap) substituteInType ts substituteInType (Array () u shape t) = arrayOf u (substituteInShape shape) (substituteInType $ Scalar t) - substituteInType (Scalar (Arrow als v t1 (RetType dims t2))) = - Scalar $ Arrow als v (substituteInType t1) $ RetType dims $ substituteInType t2 + substituteInType (Scalar (Arrow als v d1 t1 (RetType dims t2))) = + Scalar $ Arrow als v d1 (substituteInType t1) $ RetType dims $ substituteInType t2 substituteInShape (Shape ds) = Shape $ map substituteInDim ds diff --git a/src/Language/Futhark/TypeChecker/Monad.hs b/src/Language/Futhark/TypeChecker/Monad.hs index c3d84c9896..a107a9eefb 100644 --- a/src/Language/Futhark/TypeChecker/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Monad.hs @@ -432,8 +432,8 @@ qualifyTypeVars outer_env orig_except ref_qs = onType (S.fromList orig_except) Record $ M.map (onType except) m onScalar except (Sum m) = Sum $ M.map (map $ onType except) m - onScalar except (Arrow as p t1 (RetType dims t2)) = - Arrow as p (onType except' t1) $ RetType dims (onType except' t2) + onScalar except (Arrow as p d t1 (RetType dims t2)) = + Arrow as p d (onType except' t1) $ RetType dims (onType except' t2) where except' = case p of Named p' -> S.insert p' except diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index da859e46d6..44cc4d25fe 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -211,10 +211,10 @@ checkApplyExp (AppExp (Apply e1 e2 _ loc) _) = do arg <- checkArg e2 (e1', (fname, i)) <- checkApplyExp e1 t <- expType e1' - (t1, rt, argext, exts) <- checkApply loc (fname, i) t arg + (d1, _, rt, argext, exts) <- checkApply loc (fname, i) t arg pure ( AppExp - (Apply e1' (argExp arg) (Info (diet t1, argext)) loc) + (Apply e1' (argExp arg) (Info (d1, argext)) loc) (Info $ AppRes rt exts), (fname, i + 1) ) @@ -347,8 +347,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (p1_t, rt, p1_ext, _) <- checkApply loc (Just op', 0) ftype e1_arg - (p2_t, rt', p2_ext, retext) <- checkApply loc (Just op', 1) rt e2_arg + (_, p1_t, rt, p1_ext, _) <- checkApply loc (Just op', 0) ftype e1_arg + (_, p2_t, rt', p2_ext, retext) <- checkApply loc (Just op', 1) rt e2_arg pure $ AppExp @@ -471,8 +471,7 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l bindSpaced [(Term, name)] $ do name' <- checkName Term name loc - let arrow (xp, xt) yt = RetType [] $ Scalar $ Arrow () xp xt yt - RetType _ ftype = foldr (arrow . patternParam) rettype params' + let ftype = funType params' rettype entry = BoundV Local tparams' $ ftype `setAliases` closure' bindF scope = scope @@ -627,12 +626,11 @@ checkExp (Lambda params body rettype_te NoInfo loc) = do -- are parameters, or are used in parameters. inferReturnSizes params' ret = do cur_lvl <- curLevel - let named (Named x, _) = Just x - named (Unnamed, _) = Nothing + let named (Named x, _, _) = Just x + named (Unnamed, _, _) = Nothing param_names = mapMaybe (named . patternParam) params' pos_sizes = - sizeNamesPos . foldFunType (map patternStructType params') $ - RetType [] ret + sizeNamesPos $ foldFunTypeFromParams params' $ RetType [] ret hide k (lvl, _) = lvl >= cur_lvl && k `notElem` param_names && k `S.notMember` pos_sizes @@ -650,9 +648,9 @@ checkExp (OpSection op _ loc) = do checkExp (OpSectionLeft op _ e _ _ loc) = do (op', ftype) <- lookupVar loc op e_arg <- checkArg e - (t1, rt, argext, retext) <- checkApply loc (Just op', 0) ftype e_arg + (_, t1, rt, argext, retext) <- checkApply loc (Just op', 0) ftype e_arg case (ftype, rt) of - (Scalar (Arrow _ m1 _ _), Scalar (Arrow _ m2 t2 rettype)) -> + (Scalar (Arrow _ m1 _ _ _), Scalar (Arrow _ m2 _ t2 rettype)) -> pure $ OpSectionLeft op' @@ -668,12 +666,12 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do (op', ftype) <- lookupVar loc op e_arg <- checkArg e case ftype of - Scalar (Arrow as1 m1 t1 (RetType [] (Scalar (Arrow as2 m2 t2 (RetType dims2 ret))))) -> do - (t2', ret', argext, _) <- + Scalar (Arrow as1 m1 d1 t1 (RetType [] (Scalar (Arrow as2 m2 d2 t2 (RetType dims2 ret))))) -> do + (_, t2', ret', argext, _) <- checkApply loc (Just op', 1) - (Scalar $ Arrow as2 m2 t2 $ RetType [] $ Scalar $ Arrow as1 m1 t1 $ RetType [] ret) + (Scalar $ Arrow as2 m2 d2 t2 $ RetType [] $ Scalar $ Arrow as1 m1 d1 t1 $ RetType [] ret) e_arg pure $ OpSectionRight @@ -690,13 +688,13 @@ checkExp (ProjectSection fields NoInfo loc) = do a <- newTypeVar loc "a" let usage = mkUsage loc "projection at" b <- foldM (flip $ mustHaveField usage) a fields - let ft = Scalar $ Arrow mempty Unnamed (toStruct a) $ RetType [] b + let ft = Scalar $ Arrow mempty Unnamed Observe (toStruct a) $ RetType [] b pure $ ProjectSection fields (Info ft) loc checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice (t, _) <- newArrayType loc "e" $ sliceDims slice' (t', retext) <- sliceShape Nothing slice' t - let ft = Scalar $ Arrow mempty Unnamed t $ RetType retext $ fromStruct t' + let ft = Scalar $ Arrow mempty Unnamed Observe t $ RetType retext $ fromStruct t' pure $ IndexSection slice' (Info ft) loc checkExp (AppExp (DoLoop _ mergepat mergeexp form loopbody loc) _) = do ((sparams, mergepat', mergeexp', form', loopbody'), appres) <- @@ -841,7 +839,7 @@ boundInsideType (Scalar (TypeVar _ _ _ targs)) = foldMap f targs f TypeArgDim {} = mempty boundInsideType (Scalar (Record fs)) = foldMap boundInsideType fs boundInsideType (Scalar (Sum cs)) = foldMap (foldMap boundInsideType) cs -boundInsideType (Scalar (Arrow _ pn t1 (RetType dims t2))) = +boundInsideType (Scalar (Arrow _ pn _ t1 (RetType dims t2))) = pn' <> boundInsideType t1 <> S.fromList dims <> boundInsideType t2 where pn' = case pn of @@ -863,11 +861,11 @@ checkApply :: ApplyOp -> PatType -> Arg -> - TermTypeM (StructType, PatType, Maybe VName, [VName]) + TermTypeM (Diet, StructType, PatType, Maybe VName, [VName]) checkApply loc (fname, _) - (Scalar (Arrow as pname tp1 tp2)) + (Scalar (Arrow as pname d1 tp1 tp2)) (argexp, argtype, dflow, argloc) = onFailure (CheckingApply fname argexp tp1 (toStruct argtype)) $ do expect (mkUsage argloc "use as function argument") (toStruct tp1) (toStruct argtype) @@ -896,14 +894,14 @@ checkApply in zeroOrderType (mkUsage argloc "potential consumption in expression") msg tp1 _ -> pure () - arg_consumed <- consumedByArg argloc argtype' (diet tp1') + arg_consumed <- consumedByArg (locOf argloc) argtype' d1 checkIfConsumable loc $ mconcat arg_consumed occur $ dflow `seqOccurrences` map (`consumption` argloc) arg_consumed -- Unification ignores uniqueness in higher-order arguments, so -- we check for that here. unless (toStructural argtype' `subtypeOf` setUniqueness (toStructural tp1') Nonunique) $ - typeError loc mempty "Consumption/aliasing does not match." + typeError loc mempty "Difference in whether argument is consumed." (argext, parsubst) <- case pname of @@ -923,18 +921,16 @@ checkApply v <- newID "internal_app_result" modify $ \s -> s {stateNames = M.insert v (NameAppRes fname loc) $ stateNames s} let appres = S.singleton $ AliasFree v - let tp2'' = applySubst parsubst $ returnType appres tp2' (diet tp1') argtype' + let tp2'' = applySubst parsubst $ returnType appres tp2' d1 argtype' - pure (tp1', tp2'', argext, ext) + pure (d1, tp1', tp2'', argext, ext) checkApply loc fname tfun@(Scalar TypeVar {}) arg = do tv <- newTypeVar loc "b" -- Change the uniqueness of the argument type because we never want -- to infer that a function is consuming. let argt_nonunique = toStruct (argType arg) `setUniqueness` Nonunique unify (mkUsage loc "use as function") (toStruct tfun) $ - Scalar $ - Arrow mempty Unnamed argt_nonunique $ - RetType [] tv + Scalar (Arrow mempty Unnamed Observe argt_nonunique $ RetType [] tv) tfun' <- normPatType tfun checkApply loc fname tfun' arg checkApply loc (fname, prev_applied) ftype (argexp, _, _, _) = do @@ -962,27 +958,21 @@ checkApply loc (fname, prev_applied) ftype (argexp, _, _, _) = do | prev_applied == 1 = "argument" | otherwise = "arguments" -consumedByArg :: SrcLoc -> PatType -> Diet -> TermTypeM [Aliasing] -consumedByArg loc (Scalar (Record ets)) (RecordDiet ds) = - mconcat . M.elems <$> traverse (uncurry $ consumedByArg loc) (M.intersectionWith (,) ets ds) -consumedByArg loc (Scalar (Sum ets)) (SumDiet ds) = - mconcat <$> traverse (uncurry $ consumedByArg loc) (concat $ M.elems $ M.intersectionWith zip ets ds) -consumedByArg loc (Scalar (Arrow _ _ t1 _)) (FuncDiet d _) - | not $ contravariantArg t1 d = - typeError loc mempty . withIndexLink "consuming-argument" $ - "Non-consuming higher-order parameter passed consuming argument." +aliasParts :: PatType -> [Aliasing] +aliasParts (Scalar (Record ts)) = foldMap aliasParts $ M.elems ts +aliasParts t = [aliases t] + +consumedByArg :: Loc -> PatType -> Diet -> TermTypeM [Aliasing] +consumedByArg loc at Consume = do + let parts = aliasParts at + foldM_ check mempty parts + pure parts where - contravariantArg (Array _ Unique _ _) Observe = - False - contravariantArg (Scalar (TypeVar _ Unique _ _)) Observe = - False - contravariantArg (Scalar (Record ets)) (RecordDiet ds) = - and (M.intersectionWith contravariantArg ets ds) - contravariantArg (Scalar (Arrow _ _ tp (RetType _ tr))) (FuncDiet dp dr) = - contravariantArg tp dp && contravariantArg tr dr - contravariantArg _ _ = - True -consumedByArg _ at Consume = pure [aliases at] + check seen als + | any (`S.member` seen) als = + typeError loc mempty . withIndexLink "self-aliasing-arg" $ + "Argument passed for consuming parameter is self-aliased." + | otherwise = pure $ als <> seen consumedByArg _ _ _ = pure [] -- | Type-check a single expression in isolation. This expression may @@ -1283,8 +1273,8 @@ hiddenParamNames :: [Pat] -> Names hiddenParamNames params = hidden where param_all_names = mconcat $ map patNames params - named (Named x, _) = Just x - named (Unnamed, _) = Nothing + named (Named x, _, _) = Just x + named (Unnamed, _, _) = Nothing param_names = S.fromList $ mapMaybe (named . patternParam) params hidden = param_all_names `S.difference` param_names @@ -1399,7 +1389,7 @@ checkBinding (fname, maybe_retdecl, tparams, params, body, loc) = -- | Extract all the shape names that occur in positive position -- (roughly, left side of an arrow) in a given type. sizeNamesPos :: TypeBase Size als -> S.Set VName -sizeNamesPos (Scalar (Arrow _ _ t1 (RetType _ t2))) = onParam t1 <> sizeNamesPos t2 +sizeNamesPos (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = onParam t1 <> sizeNamesPos t2 where onParam :: TypeBase Size als -> S.Set VName onParam (Scalar Arrow {}) = mempty @@ -1511,7 +1501,7 @@ verifyFunctionParams fname params = where forbidden' = case patternParam p of - (Named v, _) -> forbidden `S.difference` S.singleton v + (Named v, _, _) -> forbidden `S.difference` S.singleton v _ -> forbidden verifyParams _ [] = pure () @@ -1536,8 +1526,8 @@ injectExt ext ret = RetType ext_here $ deeper ret deeper (Scalar (Prim t)) = Scalar $ Prim t deeper (Scalar (Record fs)) = Scalar $ Record $ M.map deeper fs deeper (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map deeper) cs - deeper (Scalar (Arrow als p t1 (RetType t2_ext t2))) = - Scalar $ Arrow als p t1 $ injectExt (ext_there <> t2_ext) t2 + deeper (Scalar (Arrow als p d1 t1 (RetType t2_ext t2))) = + Scalar $ Arrow als p d1 t1 $ injectExt (ext_there <> t2_ext) t2 deeper (Scalar (TypeVar as u tn targs)) = Scalar $ TypeVar as u tn $ map deeperArg targs deeper t@Array {} = t @@ -1571,7 +1561,8 @@ closeOverTypes defname defloc tparams paramts ret substs = do injectExt (retext ++ mapMaybe mkExt (S.toList $ freeInType ret)) ret ) where - t = foldFunType paramts $ RetType [] ret + -- Diet does not matter here. + t = foldFunType (zip (repeat Observe) paramts) $ RetType [] ret to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs visible = typeVars t <> freeInType t diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 1411f30563..fd61dc452d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -677,9 +677,9 @@ instance MonadTypeChecker TermTypeM where argtype <- newTypeVar loc "t" equalityType usage argtype pure $ - Scalar . Arrow mempty Unnamed argtype . RetType [] $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ Scalar $ - Arrow mempty Unnamed argtype $ + Arrow mempty Unnamed Observe argtype $ RetType [] $ Scalar $ Prim Bool @@ -687,7 +687,7 @@ instance MonadTypeChecker TermTypeM where argtype <- newTypeVar loc "t" mustBeOneOf ts usage argtype let (pts', rt') = instOverloaded argtype pts rt - arrow xt yt = Scalar $ Arrow mempty Unnamed xt $ RetType [] yt + arrow xt yt = Scalar $ Arrow mempty Unnamed Observe xt $ RetType [] yt pure $ fromStruct $ foldr arrow rt' pts' observe $ Ident name (Info t) loc @@ -1002,7 +1002,7 @@ initialTermScope = initialVtable = M.fromList $ mapMaybe addIntrinsicF $ M.toList intrinsics prim = Scalar . Prim - arrow x y = Scalar $ Arrow mempty Unnamed x y + arrow x y = Scalar $ Arrow mempty Unnamed Observe x y addIntrinsicF (name, IntrinsicMonoFun pts t) = Just (name, BoundV Global [] $ arrow pts' $ RetType [] $ prim t) @@ -1015,15 +1015,8 @@ initialTermScope = addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = Just ( name, - BoundV Global tvs $ - fromStruct $ - Scalar $ - Arrow mempty Unnamed pts' rt + BoundV Global tvs $ fromStruct $ foldFunType pts rt ) - where - pts' = case pts of - [pt] -> pt - _ -> Scalar $ tupleRecord pts addIntrinsicF (name, IntrinsicEquality) = Just (name, EqualityF) addIntrinsicF _ = Nothing diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 62f39b056f..e1ad04143c 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -72,7 +72,7 @@ mustBeExplicitInBinding bind_t = M.fromList $ zip (S.toList $ freeInType ret) $ repeat True - in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty ts + in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty $ map snd ts where onType uses t = uses <> mustBeExplicitAux t -- Left-biased union. @@ -102,8 +102,8 @@ returnType appres (Scalar (TypeVar als Nonunique t targs)) d arg = Scalar $ TypeVar (appres <> als <> arg_als) Unique t targs where arg_als = aliases $ maskAliases arg d -returnType _ (Scalar (Arrow old_als v t1 (RetType dims t2))) d arg = - Scalar $ Arrow als v (t1 `setAliases` mempty) $ RetType dims $ t2 `setAliases` als +returnType _ (Scalar (Arrow old_als v pd t1 (RetType dims t2))) d arg = + Scalar $ Arrow als v pd (t1 `setAliases` mempty) $ RetType dims $ t2 `setAliases` als where -- Make sure to propagate the aliases of an existing closure. als = old_als <> aliases (maskAliases arg d) @@ -119,12 +119,6 @@ maskAliases :: TypeBase shape as maskAliases t Consume = t `setAliases` mempty maskAliases t Observe = t -maskAliases (Scalar (Record ets)) (RecordDiet ds) = - Scalar $ Record $ M.intersectionWith maskAliases ets ds -maskAliases (Scalar (Sum ets)) (SumDiet ds) = - Scalar $ Sum $ M.intersectionWith (zipWith maskAliases) ets ds -maskAliases t FuncDiet {} = t -maskAliases _ _ = error "Invalid arguments passed to maskAliases." -- | The two types are assumed to be structurally equal, but not -- necessarily regarding sizes. Combines aliases. @@ -140,9 +134,9 @@ addAliasesFromType (Scalar (Record ts1)) (Scalar (Record ts2)) sort (M.keys ts1) == sort (M.keys ts2) = Scalar $ Record $ M.intersectionWith addAliasesFromType ts1 ts2 addAliasesFromType - (Scalar (Arrow als1 mn1 pt1 (RetType dims1 rt1))) - (Scalar (Arrow als2 _ _ (RetType _ rt2))) = - Scalar (Arrow (als1 <> als2) mn1 pt1 (RetType dims1 rt1')) + (Scalar (Arrow als1 mn1 d1 pt1 (RetType dims1 rt1))) + (Scalar (Arrow als2 _ _ _ (RetType _ rt2))) = + Scalar (Arrow (als1 <> als2) mn1 d1 pt1 (RetType dims1 rt1')) where rt1' = addAliasesFromType rt1 rt2 addAliasesFromType (Scalar (Sum cs1)) (Scalar (Sum cs2)) @@ -203,9 +197,9 @@ unifyScalarTypes uf (Record ts1) (Record ts2) (M.intersectionWith (,) ts1 ts2) unifyScalarTypes uf - (Arrow as1 mn1 t1 (RetType dims1 t1')) - (Arrow as2 _ t2 (RetType _ t2')) = - Arrow (as1 <> as2) mn1 + (Arrow as1 mn1 d1 t1 (RetType dims1 t1')) + (Arrow as2 _ _ t2 (RetType _ t2')) = + Arrow (as1 <> as2) mn1 d1 <$> unifyTypesU (flip uf) t1 t2 <*> (RetType dims1 <$> unifyTypesU uf t1' t2') unifyScalarTypes uf (Sum cs1) (Sum cs2) @@ -339,7 +333,7 @@ evalTypeExp (TEArrow (Just v) t1 t2 loc) = do pure ( TEArrow (Just v') t1' t2' loc, svars1 ++ dims1 ++ svars2, - RetType [] $ Scalar $ Arrow mempty (Named v') st1 (RetType dims2 st2), + RetType [] $ Scalar $ Arrow mempty (Named v') (diet st1) st1 (RetType dims2 st2), Lifted ) -- @@ -349,7 +343,9 @@ evalTypeExp (TEArrow Nothing t1 t2 loc) = do pure ( TEArrow Nothing t1' t2' loc, svars1 ++ dims1 ++ svars2, - RetType [] $ Scalar $ Arrow mempty Unnamed st1 $ RetType dims2 st2, + RetType [] . Scalar $ + Arrow mempty Unnamed (diet st1) (st1 `setUniqueness` Nonunique) $ + RetType dims2 st2, Lifted ) -- @@ -741,8 +737,8 @@ substTypesRet lookupSubst ot = pure $ Scalar $ TypeVar als u v targs' onType (Scalar (Record ts)) = Scalar . Record <$> traverse onType ts - onType (Scalar (Arrow als v t1 t2)) = - Scalar <$> (Arrow als v <$> onType t1 <*> onRetType t2) + onType (Scalar (Arrow als v d t1 t2)) = + Scalar <$> (Arrow als v d <$> onType t1 <*> onRetType t2) onType (Scalar (Sum ts)) = Scalar . Sum <$> traverse (traverse onType) ts diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 7f87740a54..fc6a3c1445 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -482,15 +482,15 @@ unifyWith onDims usage = subunify False (_, Scalar (TypeVar _ _ (QualName [] v2) [])) | Just lvl <- nonrigid v2 -> link (not ord) v2 lvl t1' - ( Scalar (Arrow _ p1 a1 (RetType b1_dims b1)), - Scalar (Arrow _ p2 a2 (RetType b2_dims b2)) + ( Scalar (Arrow _ p1 d1 a1 (RetType b1_dims b1)), + Scalar (Arrow _ p2 d2 a2 (RetType b2_dims b2)) ) - | uncurry (<) $ swap ord (uniqueness a1) (uniqueness a2) -> do + | uncurry (<) $ swap ord d1 d2 -> do unifyError usage mempty bcs . withIndexLink "unify-consuming-param" $ - "Parameter types" - indent 2 (pretty a1) + "Parameters" + indent 2 (pretty d1 <> pretty a1) "and" - indent 2 (pretty a2) + indent 2 (pretty d2 <> pretty a2) "are incompatible regarding consuming their arguments." | otherwise -> do -- Introduce the existentials as size variables so they diff --git a/tests/accs/intrinsics.fut b/tests/accs/intrinsics.fut index 5b893d5cc1..4f6af17c5f 100644 --- a/tests/accs/intrinsics.fut +++ b/tests/accs/intrinsics.fut @@ -8,7 +8,7 @@ def scatter_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.scatter_stream (dest, f, bs) :> *[k]a + intrinsics.scatter_stream dest f bs :> *[k]a def reduce_by_index_stream [k] 'a 'b (dest: *[k]a) @@ -17,7 +17,7 @@ def reduce_by_index_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.hist_stream (dest, op, ne, f, bs) :> *[k]a + intrinsics.hist_stream dest op ne f bs :> *[k]a def write [n] 't (acc : *acc ([n]t)) (i: i64) (v: t) : *acc ([n]t) = - intrinsics.acc_write (acc, i, v) + intrinsics.acc_write acc i v diff --git a/tests/implicit_method.fut b/tests/implicit_method.fut index cac918909b..bf05be0f65 100644 --- a/tests/implicit_method.fut +++ b/tests/implicit_method.fut @@ -44,7 +44,7 @@ -- } -def tridagSeq [n][m] (a: [n]f32,b: *[m]f32,c: [m]f32,y: *[m]f32 ): *[m]f32 = +def tridagSeq [n][m] (a: [n]f32) (b: *[m]f32) (c: [m]f32) (y: *[m]f32 ): *[m]f32 = let (y,b) = loop ((y, b)) for i < n-1 do let i = i + 1 @@ -71,7 +71,8 @@ def implicitMethod [n][m] (myD: [m][3]f32, myDD: [m][3]f32, , dtInv - 0.5*(mu*d[1] + 0.5*var*dd[1]) , 0.0 - 0.5*(mu*d[2] + 0.5*var*dd[2]))) (zip4 (mu_row) (var_row) myD myDD)) - in tridagSeq( a, copy b, c, copy u_row )) (zip3 myMu myVar u) + in tridagSeq a (copy b) c (copy u_row)) + (zip3 myMu myVar u) def main [m][n] (myD: [m][3]f32) (myDD: [m][3]f32) (myMu: [n][m]f32) (myVar: [n][m]f32) diff --git a/tests/issue-1774.fut b/tests/issue1774.fut similarity index 100% rename from tests/issue-1774.fut rename to tests/issue1774.fut diff --git a/tests/migration/intrinsics.fut b/tests/migration/intrinsics.fut index f6ad2c3b98..f10742dcb0 100644 --- a/tests/migration/intrinsics.fut +++ b/tests/migration/intrinsics.fut @@ -1,20 +1,20 @@ def flat_index_2d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) : [n1][n2]a = - intrinsics.flat_index_2d(as, offset, n1, s1, n2, s2) :> [n1][n2]a + intrinsics.flat_index_2d as offset n1 s1 n2 s2 :> [n1][n2]a def flat_update_2d [n][k][l] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (asss: [k][l]a) : *[n]a = - intrinsics.flat_update_2d(as, offset, s1, s2, asss) + intrinsics.flat_update_2d as offset s1 s2 asss def flat_index_3d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) : [n1][n2][n3]a = - intrinsics.flat_index_3d(as, offset, n1, s1, n2, s2, n3, s3) :> [n1][n2][n3]a + intrinsics.flat_index_3d as offset n1 s1 n2 s2 n3 s3 :> [n1][n2][n3]a def flat_update_3d [n][k][l][p] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (asss: [k][l][p]a) : *[n]a = - intrinsics.flat_update_3d(as, offset, s1, s2, s3, asss) + intrinsics.flat_update_3d as offset s1 s2 s3 asss def flat_index_4d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) (n4: i64) (s4: i64) : [n1][n2][n3][n4]a = - intrinsics.flat_index_4d(as, offset, n1, s1, n2, s2, n3, s3, n4, s4) :> [n1][n2][n3][n4]a + intrinsics.flat_index_4d as offset n1 s1 n2 s2 n3 s3 n4 s4 :> [n1][n2][n3][n4]a def flat_update_4d [n][k][l][p][q] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (s4: i64) (asss: [k][l][p][q]a) : *[n]a = - intrinsics.flat_update_4d(as, offset, s1, s2, s3, s4, asss) + intrinsics.flat_update_4d as offset s1 s2 s3 s4 asss type~ acc 't = intrinsics.acc t @@ -23,7 +23,7 @@ def scatter_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.scatter_stream (dest, f, bs) :> *[k]a + intrinsics.scatter_stream dest f bs :> *[k]a def reduce_by_index_stream [k] 'a 'b (dest: *[k]a) @@ -32,7 +32,7 @@ def reduce_by_index_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.hist_stream (dest, op, ne, f, bs) :> *[k]a + intrinsics.hist_stream dest op ne f bs :> *[k]a def write [n] 't (acc : *acc ([n]t)) (i: i64) (v: t) : *acc ([n]t) = - intrinsics.acc_write (acc, i, v) \ No newline at end of file + intrinsics.acc_write acc i v diff --git a/tests/slice-lmads/intrinsics.fut b/tests/slice-lmads/intrinsics.fut index ac32880267..6ac2b46804 100644 --- a/tests/slice-lmads/intrinsics.fut +++ b/tests/slice-lmads/intrinsics.fut @@ -1,17 +1,17 @@ def flat_index_2d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) : [n1][n2]a = - intrinsics.flat_index_2d(as, offset, n1, s1, n2, s2) :> [n1][n2]a + intrinsics.flat_index_2d as offset n1 s1 n2 s2 :> [n1][n2]a def flat_update_2d [n][k][l] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (asss: [k][l]a) : *[n]a = - intrinsics.flat_update_2d(as, offset, s1, s2, asss) + intrinsics.flat_update_2d as offset s1 s2 asss def flat_index_3d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) : [n1][n2][n3]a = - intrinsics.flat_index_3d(as, offset, n1, s1, n2, s2, n3, s3) :> [n1][n2][n3]a + intrinsics.flat_index_3d as offset n1 s1 n2 s2 n3 s3 :> [n1][n2][n3]a def flat_update_3d [n][k][l][p] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (asss: [k][l][p]a) : *[n]a = - intrinsics.flat_update_3d(as, offset, s1, s2, s3, asss) + intrinsics.flat_update_3d as offset s1 s2 s3 asss def flat_index_4d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) (n4: i64) (s4: i64) : [n1][n2][n3][n4]a = - intrinsics.flat_index_4d(as, offset, n1, s1, n2, s2, n3, s3, n4, s4) :> [n1][n2][n3][n4]a + intrinsics.flat_index_4d as offset n1 s1 n2 s2 n3 s3 n4 s4 :> [n1][n2][n3][n4]a def flat_update_4d [n][k][l][p][q] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (s4: i64) (asss: [k][l][p][q]a) : *[n]a = - intrinsics.flat_update_4d(as, offset, s1, s2, s3, s4, asss) + intrinsics.flat_update_4d as offset s1 s2 s3 s4 asss diff --git a/tests/uniqueness/uniqueness-error23.fut b/tests/uniqueness/uniqueness-error23.fut index a951e993c4..8cd37999e9 100644 --- a/tests/uniqueness/uniqueness-error23.fut +++ b/tests/uniqueness/uniqueness-error23.fut @@ -1,13 +1,14 @@ -- == --- error: .*consumed.* +-- error: self-aliased def g(ar: *[]i64, a: *[][]i64): i64 = ar[0] def f(ar: *[]i64, a: *[][]i64): i64 = - g(a[0], a) -- Should be a type error, as both are supposed to be unique + g(a[0], a) -- Should be a type error, as both are supposed to be + -- unique yet they alias each other. def main(n: i64): i64 = - let a = copy(replicate n (iota n)) + let a = replicate n (iota n) let ar = copy(a[0]) in f(ar, a) diff --git a/unittests/Language/Futhark/SyntaxTests.hs b/unittests/Language/Futhark/SyntaxTests.hs index 3cb8b9a242..6440bdb61c 100644 --- a/unittests/Language/Futhark/SyntaxTests.hs +++ b/unittests/Language/Futhark/SyntaxTests.hs @@ -153,11 +153,18 @@ pScalarType :: Parser (ScalarTypeBase Size ()) pScalarType = choice [try pFun, pScalarNonFun] where pFun = - uncurry (Arrow ()) <$> pParam <* lexeme "->" <*> pStructRetType + pParam <* lexeme "->" <*> pStructRetType pParam = - choice [try pNamedParam, (Unnamed,) <$> pNonFunType] - pNamedParam = - parens $ (,) <$> (Named <$> pVName) <* lexeme ":" <*> pStructType + choice + [ try pNamedParam, + do + t <- pNonFunType + pure $ Arrow () Unnamed (diet t) t + ] + pNamedParam = parens $ do + v <- pVName <* lexeme ":" + t <- pStructType + pure $ Arrow () (Named v) (diet t) t pStructRetType :: Parser StructRetType pStructRetType =