diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d16d302b0..76360e0dbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Changed +* If part of a function parameter is marked as consuming ("unique"), + the *entire* parameter is now marked as consuming. + ### Fixed * A somewhat obscure simplification rule could mess up use of memory. diff --git a/docs/error-index.rst b/docs/error-index.rst index ab75d6dccd..d985650c2d 100644 --- a/docs/error-index.rst +++ b/docs/error-index.rst @@ -149,6 +149,41 @@ inserting copies to break the aliasing: def main (xs: *[]i32) : (*[]i32, *[]i32) = (xs, copy xs) +.. _self-aliasing-arg: + +"Argument passed for consuming parameter is self-aliased." +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Caused by programs like the following: + +.. code-block:: futhark + + def g (t: *([]i64, []i64)) = 0 + + def f n = + let x = iota n + in g (x,x) + +The function ``g`` expects to consume two separate ``[]i64`` arrays, +but ``f`` passes it a tuple containing two references to the same +physical array. This is not allowed, as ``g`` must be allowed to +assume that components of consuming record- or tuple parameters have +no internal aliases. We can fix this by inserting copies to break the +aliasing: + +.. code-block:: futhark + + def f n = + let x = iota n + in g (copy (x,x)) + +Alternative, we could duplicate the expression producing the array: + +.. code-block:: futhark + + def f n = + g (iota n, iota n)) + .. _consuming-parameter: "Consuming parameter passed non-unique argument" diff --git a/docs/language-reference.rst b/docs/language-reference.rst index ae6efbcc5e..19f3942e79 100644 --- a/docs/language-reference.rst +++ b/docs/language-reference.rst @@ -1434,15 +1434,27 @@ prefixing it with an asterisk. For a return type, we can mark it as def modify (a: *[]i32) (i: i32) (x: i32): *[]i32 = a with [i] = a[i] + x +A parameter that is not consuming is called *observing*. In the +parameter declaration ``a: *[i32]``, the asterisk means that the +function ``modify`` has been given "ownership" of the array ``a``, +meaning that any caller of ``modify`` will never reference array ``a`` +after the call again. This allows the ``with`` expression to perform +an in-place update. After a call ``modify a i x``, neither ``a`` or +any variable that *aliases* ``a`` may be used on any following +execution path. + +If an asterisk is present at *any point* inside a tuple parameter +type, the parameter as a whole is considered consuming. For example:: + + def consumes_both ((a,b): (*[]i32,[]i32)) = ... + +This is usually not desirable behaviour. Use multiple parameters +instead:: + + def consumes_first_arg (a: *[]i32) (b: []i32) = ... + For bulk in-place updates with multiple values, use the ``scatter`` -function in the basis library. In the parameter declaration ``a: -*[i32]``, the asterisk means that the function ``modify`` has been -given "ownership" of the array ``a``, meaning that any caller of -``modify`` will never reference array ``a`` after the call again. -This allows the ``with`` expression to perform an in-place update. - -After a call ``modify a i x``, neither ``a`` or any variable that -*aliases* ``a`` may be used on any following execution path. +function in the basis library. Alias Analysis ~~~~~~~~~~~~~~ diff --git a/prelude/ad.fut b/prelude/ad.fut index 91eb24ffbc..251e47a21d 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -7,12 +7,12 @@ -- | Jacobian-Vector Product ("forward mode"), producing also the -- primal result as the first element of the result tuple. let jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) = - intrinsics.jvp2 (f, x, x') + intrinsics.jvp2 f x x' -- | Vector-Jacobian Product ("reverse mode"), producing also the -- primal result as the first element of the result tuple. let vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = - intrinsics.vjp2 (f, x, y') + intrinsics.vjp2 f x y' -- | Jacobian-Vector Product ("forward mode"). let jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = diff --git a/prelude/array.fut b/prelude/array.fut index 7c880af8f1..b58046a4d7 100644 --- a/prelude/array.fut +++ b/prelude/array.fut @@ -62,7 +62,7 @@ def reverse [n] 't (x: [n]t): [n]t = x[::-1] -- **Work:** O(n). -- -- **Span:** O(1). -def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = intrinsics.concat (xs, ys) +def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = intrinsics.concat xs ys -- | An old-fashioned way of saying `++`. def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = xs ++ ys @@ -83,7 +83,7 @@ def concat_to [n] [m] 't (k: i64) (xs: [n]t) (ys: [m]t): *[k]t = xs ++ ys :> [k] -- -- Note: In most cases, `rotate` will be fused with subsequent -- operations such as `map`, in which case it is free. -def rotate [n] 't (r: i64) (xs: [n]t): [n]t = intrinsics.rotate (r, xs) +def rotate [n] 't (r: i64) (xs: [n]t): [n]t = intrinsics.rotate r xs -- | Construct an array of consecutive integers of the given length, -- starting at 0. @@ -143,7 +143,7 @@ def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): []t = -- -- **Complexity:** O(1). def unflatten [p] 't (n: i64) (m: i64) (xs: [p]t): [n][m]t = - intrinsics.unflatten (n, m, xs) :> [n][m]t + intrinsics.unflatten n m xs :> [n][m]t -- | Like `unflatten`, but produces three dimensions. def unflatten_3d [p] 't (n: i64) (m: i64) (l: i64) (xs: [p]t): [n][m][l]t = diff --git a/prelude/soacs.fut b/prelude/soacs.fut index 67a43b4fb6..fe7f632ccf 100644 --- a/prelude/soacs.fut +++ b/prelude/soacs.fut @@ -48,7 +48,7 @@ import "zip" -- -- **Span:** *O(S(f))* def map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x = - intrinsics.map (f, as) + intrinsics.map f as -- | Apply the given function to each element of a single array. -- @@ -104,7 +104,7 @@ def map5 'a 'b 'c 'd 'e [n] 'x (f: a -> b -> c -> d -> e -> x) (as: [n]a) (bs: [ -- Note that the complexity implies that parallelism in the combining -- operator will *not* be exploited. def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = - intrinsics.reduce (op, ne, as) + intrinsics.reduce op ne as -- | As `reduce`, but the operator must also be commutative. This is -- potentially faster than `reduce`. For simple built-in operators, @@ -115,7 +115,7 @@ def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = -- -- **Span:** *O(log(n) ✕ W(op))* def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = - intrinsics.reduce_comm (op, ne, as) + intrinsics.reduce_comm op ne as -- | `h = hist op ne k is as` computes a generalised `k`-bin histogram -- `h`, such that `h[i]` is the sum of those values `as[j]` for which @@ -130,7 +130,7 @@ def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a = -- -- In practice, linear span only occurs if *k* is also very large. def hist 'a [n] (op: a -> a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k]a = - intrinsics.hist_1d (1, map (\_ -> ne) (0..1.. ne) (0..1.. a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k -- -- In practice, linear span only occurs if *k* is also very large. def reduce_by_index 'a [k] [n] (dest : *[k]a) (f : a -> a -> a) (ne : a) (is : [n]i64) (as : [n]a) : *[k]a = - intrinsics.hist_1d (1, dest, f, ne, is, as) + intrinsics.hist_1d 1 dest f ne is as -- | As `reduce_by_index`, but with two-dimensional indexes. def reduce_by_index_2d 'a [k] [n] [m] (dest : *[k][m]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64)) (as : [n]a) : *[k][m]a = - intrinsics.hist_2d (1, dest, f, ne, is, as) + intrinsics.hist_2d 1 dest f ne is as -- | As `reduce_by_index`, but with three-dimensional indexes. def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64,i64)) (as : [n]a) : *[k][m][l]a = - intrinsics.hist_3d (1, dest, f, ne, is, as) + intrinsics.hist_3d 1 dest f ne is as -- | Inclusive prefix scan. Has the same caveats with respect to -- associativity and complexity as `reduce`. @@ -160,7 +160,7 @@ def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a) -- -- **Span:** *O(log(n) ✕ W(op))* def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a = - intrinsics.scan (op, ne, as) + intrinsics.scan op ne as -- | Remove all those elements of `as` that do not satisfy the -- predicate `p`. @@ -169,7 +169,7 @@ def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a = -- -- **Span:** *O(log(n) ✕ W(p))* def filter [n] 'a (p: a -> bool) (as: [n]a): *[]a = - let (as', is) = intrinsics.partition (1, \x -> if p x then 0 else 1, as) + let (as', is) = intrinsics.partition 1 (\x -> if p x then 0 else 1) as in as'[:is[0]] -- | Split an array into those elements that satisfy the given @@ -180,7 +180,7 @@ def filter [n] 'a (p: a -> bool) (as: [n]a): *[]a = -- **Span:** *O(log(n) ✕ W(p))* def partition [n] 'a (p: a -> bool) (as: [n]a): ([]a, []a) = let p' x = if p x then 0 else 1 - let (as', is) = intrinsics.partition (2, p', as) + let (as', is) = intrinsics.partition 2 p' as in (as'[0:is[0]], as'[is[0]:n]) -- | Split an array by two predicates, producing three arrays. @@ -190,7 +190,7 @@ def partition [n] 'a (p: a -> bool) (as: [n]a): ([]a, []a) = -- **Span:** *O(log(n) ✕ (W(p1) + W(p2)))* def partition2 [n] 'a (p1: a -> bool) (p2: a -> bool) (as: [n]a): ([]a, []a, []a) = let p' x = if p1 x then 0 else if p2 x then 1 else 2 - let (as', is) = intrinsics.partition (3, p', as) + let (as', is) = intrinsics.partition 3 p' as in (as'[0:is[0]], as'[is[0]:is[0]+is[1]], as'[is[0]+is[1]:n]) -- | Return `true` if the given function returns `true` for all @@ -223,7 +223,7 @@ def any [n] 'a (f: a -> bool) (as: [n]a): bool = -- -- **Span:** *O(1)* def spread 't [n] (k: i64) (x: t) (is: [n]i64) (vs: [n]t): *[k]t = - intrinsics.scatter (map (\_ -> x) (0..1.. x) (0..1.. x) (as: [n]a): [n]x = - intrinsics.map (f, as) + intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = - intrinsics.zip (as, bs) + intrinsics.zip as bs -- | Construct an array of pairs from two arrays. def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) = diff --git a/src/Futhark/Doc/Generator.hs b/src/Futhark/Doc/Generator.hs index 81e7fdcf1d..574acf158d 100644 --- a/src/Futhark/Doc/Generator.hs +++ b/src/Futhark/Doc/Generator.hs @@ -419,7 +419,7 @@ valBindHtml name (ValBind _ _ retdecl (Info rettype) tparams params _ _ _ _) = d map typeParamName tparams ++ map identName (S.toList $ mconcat $ map patIdents params) rettype' <- noLink' $ maybe (retTypeHtml rettype) typeExpHtml retdecl - params' <- noLink' $ mapM patternHtml params + params' <- noLink' $ mapM paramHtml params pure ( keyword "val " <> (H.span ! A.class_ "decl_name") name, tparams', @@ -493,6 +493,10 @@ synopsisValBindBind (name, BoundV tps t) = do <> ": " <> t' +dietHtml :: Diet -> Html +dietHtml Consume = "*" +dietHtml Observe = "" + typeHtml :: StructType -> DocM Html typeHtml t = case t of Array _ u shape et -> do @@ -513,14 +517,14 @@ typeHtml t = case t of targs' <- mapM typeArgHtml targs et' <- qualNameHtml et pure $ prettyU u <> et' <> mconcat (map (" " <>) targs') - Scalar (Arrow _ pname t1 t2) -> do + Scalar (Arrow _ pname d t1 t2) -> do t1' <- typeHtml t1 t2' <- retTypeHtml t2 pure $ case pname of Named v -> - parens (vnameHtml v <> ": " <> t1') <> " -> " <> t2' + parens (vnameHtml v <> ": " <> dietHtml d <> t1') <> " -> " <> t2' Unnamed -> - t1' <> " -> " <> t2' + dietHtml d <> t1' <> " -> " <> t2' Scalar (Sum cs) -> pipes <$> mapM ppClause (sortConstrs cs) where ppClause (n, ts) = joinBy " " . (ppConstr n :) <$> mapM typeHtml ts @@ -688,12 +692,12 @@ vnameLink' (VName _ tag) current file = then "#" ++ show tag else relativise file current ++ ".html#" ++ show tag -patternHtml :: Pat -> DocM Html -patternHtml pat = do - let (pat_param, t) = patternParam pat +paramHtml :: Pat -> DocM Html +paramHtml pat = do + let (pat_param, d, t) = patternParam pat t' <- typeHtml t pure $ case pat_param of - Named v -> parens (vnameHtml v <> ": " <> t') + Named v -> parens (vnameHtml v <> ": " <> dietHtml d <> t') Unnamed -> t' relativise :: FilePath -> FilePath -> FilePath diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 00bfa9fdff..36c5d20e05 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -384,7 +384,7 @@ defuncFun tparams pats e0 ret loc = do [pat'] -> (pat', ret, e0) (pat' : pats') -> ( pat', - RetType [] $ foldFunType (map (toStruct . patternType) pats') ret, + RetType [] $ funType pats' ret, Lambda pats' e0 Nothing (Info (mempty, ret)) loc ) @@ -533,15 +533,6 @@ defuncExp (AppExp (If e1 e2 e3 loc) res) = do (e2', sv) <- defuncExp e2 (e3', _) <- defuncExp e3 pure (AppExp (If e1' e2' e3' loc) res, sv) -defuncExp e@(AppExp (Apply f@(Var f' _ _) arg d loc) res) - | baseTag (qualLeaf f') <= maxIntrinsicTag, - TupLit es tuploc <- arg = do - -- defuncSoacExp also works fine for non-SOACs. - es' <- mapM defuncSoacExp es - pure - ( AppExp (Apply f (TupLit es' tuploc) d loc) res, - Dynamic $ typeOf e - ) defuncExp e@(AppExp Apply {} _) = defuncApply 0 e defuncExp (Negate e0 loc) = do (e0', sv) <- defuncExp e0 @@ -720,15 +711,19 @@ etaExpand e_t e = do (Info (AppRes (foldFunType argtypes ret) [])) ) e - $ zip3 vars (map snd ps) (drop 1 $ tails $ map snd ps) + $ zip3 vars (map (snd . snd) ps) (drop 1 $ tails $ map snd ps) pure (pats, e', second (const ()) ret) where - getType (RetType _ (Scalar (Arrow _ p t1 t2))) = - let (ps, r) = getType t2 in ((p, t1) : ps, r) + getType (RetType _ (Scalar (Arrow _ p d t1 t2))) = + let (ps, r) = getType t2 in ((p, (d, t1)) : ps, r) getType t = ([], t) - f prev (p, t) = do - let t' = fromStruct t + f prev (p, (d, t)) = do + let t' = + fromStruct t + `setUniqueness` case d of + Consume -> Unique + Observe -> Nonunique x <- case p of Named x | x `notElem` prev -> pure x _ -> newNameFromString "x" @@ -820,10 +815,9 @@ defuncApply :: Int -> Exp -> DefM (Exp, StaticVal) defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do let (argtypes, _) = unfoldFunType ret (e1', sv1) <- defuncApply (depth + 1) e1 - (e2', sv2) <- defuncExp e2 - let e' = AppExp (Apply e1' e2' d loc) t case sv1 of LambdaSV pat e0_t e0 closure_env -> do + (e2', sv2) <- defuncExp e2 let env' = matchPatSV pat sv2 dims = mempty (e0', sv) <- @@ -880,13 +874,15 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do let t1 = toStruct $ typeOf e1' t2 = toStruct $ typeOf e2' + d1 = Observe + d2 = Observe fname' = qualName fname fname'' = Var fname' ( Info - ( Scalar . Arrow mempty Unnamed t1 . RetType [] $ - Scalar . Arrow mempty Unnamed t2 $ + ( Scalar . Arrow mempty Unnamed d1 t1 . RetType [] $ + Scalar . Arrow mempty Unnamed d2 t2 $ RetType [] lifted_rettype ) ) @@ -896,7 +892,7 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do innercallret = AppRes - (Scalar $ Arrow mempty Unnamed t2 $ RetType [] lifted_rettype) + (Scalar $ Arrow mempty Unnamed d2 t2 $ RetType [] lifted_rettype) [] pure @@ -918,6 +914,7 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do -- but we update the types since it may be partially applied or return -- a higher-order term. DynamicFun _ sv -> do + (e2', _) <- defuncExp e2 let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) `setAliases` aliases ret callret = AppRes (combineTypeShapes ret restype) ext @@ -925,8 +922,14 @@ defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do pure (apply_e, sv) -- Propagate the 'IntrinsicsSV' until we reach the outermost application, -- where we construct a dynamic static value with the appropriate type. - IntrinsicSV -> intrinsicOrHole argtypes e' sv1 - HoleSV {} -> intrinsicOrHole argtypes e' sv1 + IntrinsicSV -> do + e2' <- defuncSoacExp e2 + let e' = AppExp (Apply e1' e2' d loc) t + intrinsicOrHole argtypes e' sv1 + HoleSV {} -> do + (e2', _) <- defuncExp e2 + let e' = AppExp (Apply e1' e2' d loc) t + intrinsicOrHole argtypes e' sv1 _ -> error $ "Application of an expression\n" @@ -1129,9 +1132,10 @@ typeFromSV IntrinsicSV = -- | Construct the type for a fully-applied dynamic function from its -- static value and the original types of its arguments. -dynamicFunType :: StaticVal -> [StructType] -> ([PatType], PatType) +dynamicFunType :: StaticVal -> [(Diet, StructType)] -> ([(Diet, PatType)], PatType) dynamicFunType (DynamicFun _ sv) (p : ps) = - let (ps', ret) = dynamicFunType sv ps in (fromStruct p : ps', ret) + let (ps', ret) = dynamicFunType sv ps + in (second fromStruct p : ps', ret) dynamicFunType sv _ = ([], typeFromSV sv) -- | Match a pattern with its static value. Returns an environment with diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 1a45cfd2ba..031c304cc0 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1602,13 +1602,13 @@ isIntrinsicFunction qname args loc = do fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing - handleSOACs [TupLit [lam, arr] _] "map" = Just $ \desc -> do + handleSOACs [lam, arr] "map" = Just $ \desc -> do arr' <- internaliseExpToVars "map_arr" arr arr_ts <- mapM lookupType arr' lam' <- internaliseLambdaCoerce lam $ map rowType arr_ts let w = arraysSize 0 arr_ts letTupExp' desc $ I.Op $ I.Screma w arr' (I.mapSOAC lam') - handleSOACs [TupLit [k, lam, arr] _] "partition" = do + handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do arrs <- internaliseExpToVars "partition_input" arr @@ -1618,43 +1618,43 @@ isIntrinsicFunction qname args loc = do fromInt32 (Literal (SignedValue (Int32Value k')) _) = Just k' fromInt32 (IntLit k' (Info (E.Scalar (E.Prim (E.Signed Int32)))) _) = Just $ fromInteger k' fromInt32 _ = Nothing - handleSOACs [TupLit [lam, ne, arr] _] "reduce" = Just $ \desc -> + handleSOACs [lam, ne, arr] "reduce" = Just $ \desc -> internaliseScanOrReduce desc "reduce" reduce (lam, ne, arr, loc) where reduce w red_lam nes arrs = I.Screma w arrs <$> I.reduceSOAC [Reduce Noncommutative red_lam nes] - handleSOACs [TupLit [lam, ne, arr] _] "reduce_comm" = Just $ \desc -> + handleSOACs [lam, ne, arr] "reduce_comm" = Just $ \desc -> internaliseScanOrReduce desc "reduce" reduce (lam, ne, arr, loc) where reduce w red_lam nes arrs = I.Screma w arrs <$> I.reduceSOAC [Reduce Commutative red_lam nes] - handleSOACs [TupLit [lam, ne, arr] _] "scan" = Just $ \desc -> + handleSOACs [lam, ne, arr] "scan" = Just $ \desc -> internaliseScanOrReduce desc "scan" reduce (lam, ne, arr, loc) where reduce w scan_lam nes arrs = I.Screma w arrs <$> I.scanSOAC [Scan scan_lam nes] - handleSOACs [TupLit [rf, dest, op, ne, buckets, img] _] "hist_1d" = Just $ \desc -> + handleSOACs [rf, dest, op, ne, buckets, img] "hist_1d" = Just $ \desc -> internaliseHist 1 desc rf dest op ne buckets img loc - handleSOACs [TupLit [rf, dest, op, ne, buckets, img] _] "hist_2d" = Just $ \desc -> + handleSOACs [rf, dest, op, ne, buckets, img] "hist_2d" = Just $ \desc -> internaliseHist 2 desc rf dest op ne buckets img loc - handleSOACs [TupLit [rf, dest, op, ne, buckets, img] _] "hist_3d" = Just $ \desc -> + handleSOACs [rf, dest, op, ne, buckets, img] "hist_3d" = Just $ \desc -> internaliseHist 3 desc rf dest op ne buckets img loc handleSOACs _ _ = Nothing - handleAccs [TupLit [dest, f, bs] _] "scatter_stream" = Just $ \desc -> + handleAccs [dest, f, bs] "scatter_stream" = Just $ \desc -> internaliseStreamAcc desc dest Nothing f bs - handleAccs [TupLit [dest, op, ne, f, bs] _] "hist_stream" = Just $ \desc -> + handleAccs [dest, op, ne, f, bs] "hist_stream" = Just $ \desc -> internaliseStreamAcc desc dest (Just (op, ne)) f bs - handleAccs [TupLit [acc, i, v] _] "acc_write" = Just $ \desc -> do + handleAccs [acc, i, v] "acc_write" = Just $ \desc -> do acc' <- head <$> internaliseExpToVars "acc" acc i' <- internaliseExp1 "acc_i" i vs <- internaliseExp "acc_v" v fmap pure $ letSubExp desc $ BasicOp $ UpdateAcc acc' [i'] vs handleAccs _ _ = Nothing - handleAD [TupLit [f, x, v] _] fname + handleAD [f, x, v] fname | fname `elem` ["jvp2", "vjp2"] = Just $ \desc -> do x' <- internaliseExp "ad_x" x v' <- internaliseExp "ad_v" v @@ -1665,10 +1665,10 @@ isIntrinsicFunction qname args loc = do _ -> VJP lam x' v' handleAD _ _ = Nothing - handleRest [E.TupLit [a, si, v] _] "scatter" = Just $ scatterF 1 a si v - handleRest [E.TupLit [a, si, v] _] "scatter_2d" = Just $ scatterF 2 a si v - handleRest [E.TupLit [a, si, v] _] "scatter_3d" = Just $ scatterF 3 a si v - handleRest [E.TupLit [n, m, arr] _] "unflatten" = Just $ \desc -> do + handleRest [a, si, v] "scatter" = Just $ scatterF 1 a si v + handleRest [a, si, v] "scatter_2d" = Just $ scatterF 2 a si v + handleRest [a, si, v] "scatter_3d" = Just $ scatterF 3 a si v + handleRest [n, m, arr] "unflatten" = Just $ \desc -> do arrs <- internaliseExpToVars "unflatten_arr" arr n' <- internaliseExp1 "n" n m' <- internaliseExp1 "m" m @@ -1716,7 +1716,7 @@ isIntrinsicFunction qname args loc = do I.ReshapeArbitrary (reshapeOuter (I.Shape [k]) 2 $ I.arrayShape arr_t) arr' - handleRest [TupLit [x, y] _] "concat" = Just $ \desc -> do + handleRest [x, y] "concat" = Just $ \desc -> do xs <- internaliseExpToVars "concat_x" x ys <- internaliseExpToVars "concat_y" y outer_size <- arraysSize 0 <$> mapM lookupType xs @@ -1731,7 +1731,7 @@ isIntrinsicFunction qname args loc = do let conc xarr yarr = I.BasicOp $ I.Concat 0 (xarr :| [yarr]) ressize mapM (letSubExp desc) $ zipWith conc xs ys - handleRest [TupLit [offset, e] _] "rotate" = Just $ \desc -> do + handleRest [offset, e] "rotate" = Just $ \desc -> do offset' <- internaliseExp1 "rotation_offset" offset internaliseOperation desc e $ \v -> do r <- I.arrayRank <$> lookupType v @@ -1742,24 +1742,24 @@ isIntrinsicFunction qname args loc = do internaliseOperation desc e $ \v -> do r <- I.arrayRank <$> lookupType v pure $ I.Rearrange ([1, 0] ++ [2 .. r - 1]) v - handleRest [TupLit [x, y] _] "zip" = Just $ \desc -> + handleRest [x, y] "zip" = Just $ \desc -> mapM (letSubExp "zip_copy" . BasicOp . Copy) =<< ( (++) <$> internaliseExpToVars (desc ++ "_zip_x") x <*> internaliseExpToVars (desc ++ "_zip_y") y ) handleRest [x] "unzip" = Just $ flip internaliseExp x - handleRest [TupLit [arr, offset, n1, s1, n2, s2] _] "flat_index_2d" = Just $ \desc -> do + handleRest [arr, offset, n1, s1, n2, s2] "flat_index_2d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2)] - handleRest [TupLit [arr1, offset, s1, s2, arr2] _] "flat_update_2d" = Just $ \desc -> do + handleRest [arr1, offset, s1, s2, arr2] "flat_update_2d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2] arr2 - handleRest [TupLit [arr, offset, n1, s1, n2, s2, n3, s3] _] "flat_index_3d" = Just $ \desc -> do + handleRest [arr, offset, n1, s1, n2, s2, n3, s3] "flat_index_3d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2), (n3, s3)] - handleRest [TupLit [arr1, offset, s1, s2, s3, arr2] _] "flat_update_3d" = Just $ \desc -> do + handleRest [arr1, offset, s1, s2, s3, arr2] "flat_update_3d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2, s3] arr2 - handleRest [TupLit [arr, offset, n1, s1, n2, s2, n3, s3, n4, s4] _] "flat_index_4d" = Just $ \desc -> do + handleRest [arr, offset, n1, s1, n2, s2, n3, s3, n4, s4] "flat_index_4d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2), (n3, s3), (n4, s4)] - handleRest [TupLit [arr1, offset, s1, s2, s3, s4, arr2] _] "flat_update_4d" = Just $ \desc -> do + handleRest [arr1, offset, s1, s2, s3, s4, arr2] "flat_update_4d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2, s3, s4] arr2 handleRest _ _ = Nothing diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index b9e14dab2b..643639473a 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -221,7 +221,7 @@ transformFName loc fname t ( i - 1, AppExp (Apply f size_arg (Info (Observe, Nothing)) loc) - (Info $ AppRes (foldFunType (replicate i i64) (RetType [] (fromStruct t))) []) + (Info $ AppRes (foldFunType (replicate i (Observe, i64)) (RetType [] (fromStruct t))) []) ) applySizeArgs fname' t' size_args = @@ -233,7 +233,7 @@ transformFName loc fname t (qualName fname') ( Info ( foldFunType - (map (const i64) size_args) + (map (const (Observe, i64)) size_args) (RetType [] $ fromStruct t') ) ) @@ -540,7 +540,7 @@ desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (Ret (Info (Observe, xext)) loc ) - (Info $ AppRes (Scalar $ Arrow mempty yp ytype (RetType [] t)) []) + (Info $ AppRes (Scalar $ Arrow mempty yp Observe ytype (RetType [] t)) []) rettype' = let onDim (NamedSize d) | Named p <- xp, qualLeaf d == p = NamedSize $ qualName v1 @@ -580,7 +580,7 @@ desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (Ret pure (v, id, var_e, [pat]) desugarProjectSection :: [Name] -> PatType -> SrcLoc -> MonoM Exp -desugarProjectSection fields (Scalar (Arrow _ _ t1 (RetType dims t2))) loc = do +desugarProjectSection fields (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do p <- newVName "project_p" let body = foldl project (Var (qualName p) (Info t1') mempty) fields pure $ @@ -606,7 +606,7 @@ desugarProjectSection fields (Scalar (Arrow _ _ t1 (RetType dims t2))) loc = do desugarProjectSection _ t _ = error $ "desugarOpSection: not a function type: " ++ prettyString t desugarIndexSection :: [DimIndex] -> PatType -> SrcLoc -> MonoM Exp -desugarIndexSection idxs (Scalar (Arrow _ _ t1 (RetType dims t2))) loc = do +desugarIndexSection idxs (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do p <- newVName "index_i" let body = AppExp (Index (Var (qualName p) (Info t1') loc) idxs loc) (Info (AppRes t2 [])) pure $ @@ -719,8 +719,8 @@ noNamedParams = f where f (Array () u shape t) = Array () u shape (f' t) f (Scalar t) = Scalar $ f' t - f' (Arrow () _ t1 (RetType dims t2)) = - Arrow () Unnamed (f t1) (RetType dims (f t2)) + f' (Arrow () _ d1 t1 (RetType dims t2)) = + Arrow () Unnamed d1 (f t1) (RetType dims (f t2)) f' (Record fs) = Record $ fmap f fs f' (Sum cs) = @@ -737,7 +737,7 @@ monomorphiseBinding :: MonoM (VName, InferSizeArgs, ValBind) monomorphiseBinding entry (PolyBinding rr (name, tparams, params, rettype, body, attrs, loc)) inst_t = replaceRecordReplacements rr $ do - let bind_t = foldFunType (map patternStructType params) rettype + let bind_t = funType params rettype (substs, t_shape_params) <- typeSubstsM loc (noSizes bind_t) $ noNamedParams inst_t let substs' = M.map (Subst []) substs rettype' = applySubst (`M.lookup` substs') rettype @@ -840,7 +840,7 @@ typeSubstsM loc orig_t1 orig_t2 = (map snd $ sortFields fields1) (map snd $ sortFields fields2) sub (Scalar Prim {}) (Scalar Prim {}) = pure () - sub (Scalar (Arrow _ _ t1a (RetType _ t1b))) (Scalar (Arrow _ _ t2a t2b)) = do + sub (Scalar (Arrow _ _ _ t1a (RetType _ t1b))) (Scalar (Arrow _ _ _ t2a t2b)) = do sub t1a t2a subRet t1b t2b sub (Scalar (Sum cs1)) (Scalar (Sum cs2)) = @@ -940,11 +940,10 @@ transformValBind valbind = do Nothing -> pure () Just (Info entry) -> do t <- - removeTypeVariablesInType - $ foldFunType - (map patternStructType (valBindParams valbind)) - $ unInfo - $ valBindRetType valbind + removeTypeVariablesInType $ + funType (valBindParams valbind) $ + unInfo $ + valBindRetType valbind (name, infer, valbind'') <- monomorphiseBinding True valbind' $ monoType t entry' <- transformEntryPoint entry tell $ Seq.singleton (name, valbind'' {valBindEntryPoint = Just $ Info entry'}) diff --git a/src/Language/Futhark/FreeVars.hs b/src/Language/Futhark/FreeVars.hs index 5e8c0dfe83..76c89d73dd 100644 --- a/src/Language/Futhark/FreeVars.hs +++ b/src/Language/Futhark/FreeVars.hs @@ -148,7 +148,7 @@ freeInType t = mempty Scalar (Sum cs) -> foldMap (foldMap freeInType) cs - Scalar (Arrow _ v t1 (RetType dims t2)) -> + Scalar (Arrow _ v _ t1 (RetType dims t2)) -> S.filter (notV v) $ S.filter (`notElem` dims) $ freeInType t1 <> freeInType t2 Scalar (TypeVar _ _ _ targs) -> foldMap typeArgDims targs diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 5b3ee8f216..d430061039 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -154,8 +154,8 @@ resolveTypeParams names = match M.elems $ M.intersectionWith (zipWith match) poly_fields fields match - (Scalar (Arrow _ _ poly_t1 (RetType _ poly_t2))) - (Scalar (Arrow _ _ t1 (RetType _ t2))) = + (Scalar (Arrow _ _ _ poly_t1 (RetType _ poly_t2))) + (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = match poly_t1 t1 <> match poly_t2 t2 match poly_t t | d1 : _ <- shapeDims (arrayShape poly_t), @@ -547,8 +547,8 @@ evalIndex loc env is arr = do evalType :: Env -> StructType -> StructType evalType _ (Scalar (Prim pt)) = Scalar $ Prim pt evalType env (Scalar (Record fs)) = Scalar $ Record $ fmap (evalType env) fs -evalType env (Scalar (Arrow () p t1 (RetType dims t2))) = - Scalar $ Arrow () p (evalType env t1) (RetType dims (evalType env t2)) +evalType env (Scalar (Arrow () p d t1 (RetType dims t2))) = + Scalar $ Arrow () p d (evalType env t1) (RetType dims (evalType env t2)) evalType env t@(Array _ u shape _) = let et = stripArray (shapeRank shape) t et' = evalType env et @@ -611,7 +611,7 @@ evalFunction env _ [] body rettype = -- Eta-expand the rest to make any sizes visible. etaExpand [] env rettype where - etaExpand vs env' (Scalar (Arrow _ _ pt (RetType _ rt))) = + etaExpand vs env' (Scalar (Arrow _ _ _ pt (RetType _ rt))) = pure $ ValueFun $ \v -> do env'' <- matchPat env' (Wildcard (Info $ fromStruct pt) noLoc) v @@ -640,7 +640,7 @@ evalFunctionBinding :: EvalM TermBinding evalFunctionBinding env tparams ps ret fbody = do let ret' = evalType env $ retType ret - arrow (xp, xt) yt = Scalar $ Arrow () xp xt $ RetType [] yt + arrow (xp, d, xt) yt = Scalar $ Arrow () xp d xt $ RetType [] yt ftype = foldr (arrow . patternParam) ret' ps retext = case ps of [] -> retDims ret @@ -1231,57 +1231,61 @@ initialCtx = fun1 f = TermValue Nothing $ ValueFun $ \x -> f x + fun2 f = - TermValue Nothing $ - ValueFun $ \x -> - pure $ ValueFun $ \y -> f x y - fun2t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y] -> f x y - _ -> error $ "Expected pair; got: " <> show v - fun3t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z] -> f x y z - _ -> error $ "Expected triple; got: " <> show v - - fun5t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b] -> f x y z a b - _ -> error $ "Expected pentuple; got: " <> show v - - fun6t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c] -> f x y z a b c - _ -> error $ "Expected sextuple; got: " <> show v - - fun7t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c, d] -> f x y z a b c d - _ -> error $ "Expected septuple; got: " <> show v - - fun8t f = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c, d, e] -> f x y z a b c d e - _ -> error $ "Expected sextuple; got: " <> show v - - fun10t fun = - TermValue Nothing $ - ValueFun $ \v -> - case fromTuple v of - Just [x, y, z, a, b, c, d, e, f, g] -> fun x y z a b c d e f g - _ -> error $ "Expected octuple; got: " <> show v + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> f x y + + fun3 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> f x y z + + fun5 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> f x y z a b + + fun6 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> f x y z a b c + + fun7 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> + pure . ValueFun $ \d -> f x y z a b c d + + fun8 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> + pure . ValueFun $ \d -> + pure . ValueFun $ \e -> f x y z a b c d e + + fun10 f = + TermValue Nothing . ValueFun $ \x -> + pure . ValueFun $ \y -> + pure . ValueFun $ \z -> + pure . ValueFun $ \a -> + pure . ValueFun $ \b -> + pure . ValueFun $ \c -> + pure . ValueFun $ \d -> + pure . ValueFun $ \e -> + pure . ValueFun $ \g -> + pure . ValueFun $ \h -> f x y z a b c d e g h bopDef fs = fun2 $ \x y -> case (x, y) of @@ -1459,14 +1463,14 @@ initialCtx = _ -> error $ "Cannot unsign: " <> show x def s | "map_stream" `isPrefixOf` s = - Just $ fun2t stream + Just $ fun2 stream def s | "reduce_stream" `isPrefixOf` s = - Just $ fun3t $ \_ f arg -> stream f arg + Just $ fun3 $ \_ f arg -> stream f arg def "map" = Just $ TermPoly Nothing $ \t -> pure $ - ValueFun $ \v -> - case (fromTuple v, unfoldFunType t) of - (Just [f, xs], ([_], ret_t)) + ValueFun $ \f -> pure . ValueFun $ \xs -> + case unfoldFunType t of + ([_, _], ret_t) | Just rowshape <- typeRowShape ret_t -> toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) | otherwise -> @@ -1474,21 +1478,21 @@ initialCtx = _ -> error $ "Invalid arguments to map intrinsic:\n" - ++ unlines [prettyString t, show v] + ++ unlines [prettyString t, show f, show xs] where typeRowShape = sequenceA . structTypeShape mempty . stripArray 1 def s | "reduce" `isPrefixOf` s = Just $ - fun3t $ \f ne xs -> + fun3 $ \f ne xs -> foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs def "scan" = Just $ - fun3t $ \f ne xs -> do + fun3 $ \f ne xs -> do let next (out, acc) x = do x' <- apply2 noLoc mempty f acc x pure (x' : out, x') toArray' (valueShape ne) . reverse . fst <$> foldM next ([], ne) (snd $ fromArray xs) def "scatter" = Just $ - fun3t $ \arr is vs -> + fun3 $ \arr is vs -> case arr of ValueArray shape arr' -> pure $ @@ -1503,7 +1507,7 @@ initialCtx = then arr' // [(i, v)] else arr' def "scatter_2d" = Just $ - fun3t $ \arr is vs -> + fun3 $ \arr is vs -> case arr of ValueArray _ _ -> pure $ @@ -1518,7 +1522,7 @@ initialCtx = update _ _ = error "scatter_2d expects 2-dimensional indices" def "scatter_3d" = Just $ - fun3t $ \arr is vs -> + fun3 $ \arr is vs -> case arr of ValueArray _ _ -> pure $ @@ -1532,7 +1536,7 @@ initialCtx = fromMaybe arr $ writeArray (map (IndexingFix . asInt64) idxs) arr v update _ _ = error "scatter_3d expects 3-dimensional indices" - def "hist_1d" = Just . fun6t $ \_ arr fun _ is vs -> + def "hist_1d" = Just . fun6 $ \_ arr fun _ is vs -> foldM (update fun) arr @@ -1541,7 +1545,7 @@ initialCtx = op = apply2 mempty mempty update fun arr (i, v) = fromMaybe arr <$> updateArray (op fun) [IndexingFix i] arr v - def "hist_2d" = Just . fun6t $ \_ arr fun _ is vs -> + def "hist_2d" = Just . fun6 $ \_ arr fun _ is vs -> foldM (update fun) arr @@ -1553,7 +1557,7 @@ initialCtx = <$> updateArray (op fun) (map (IndexingFix . asInt64) idxs) arr v update _ _ _ = error "hist_2d: bad index value" - def "hist_3d" = Just . fun6t $ \_ arr fun _ is vs -> + def "hist_3d" = Just . fun6 $ \_ arr fun _ is vs -> foldM (update fun) arr @@ -1566,7 +1570,7 @@ initialCtx = update _ _ _ = error "hist_2d: bad index value" def "partition" = Just $ - fun3t $ \k f xs -> do + fun3 $ \k f xs -> do let (ShapeDim _ rowshape, xs') = fromArray xs next outs x = do @@ -1586,7 +1590,7 @@ initialCtx = insertAt i x (l : ls) = l : insertAt (i - 1) x ls insertAt _ _ ls = ls def "scatter_stream" = Just $ - fun3t $ \dest f vs -> + fun3 $ \dest f vs -> case (dest, vs) of ( ValueArray dest_shape dest_arr, ValueArray _ vs_arr @@ -1601,7 +1605,7 @@ initialCtx = _ -> error $ "scatter_stream expects array, but got: " <> prettyString (show vs, show vs) def "hist_stream" = Just $ - fun5t $ \dest op _ne f vs -> + fun5 $ \dest op _ne f vs -> case (dest, vs) of ( ValueArray dest_shape dest_arr, ValueArray _ vs_arr @@ -1616,7 +1620,7 @@ initialCtx = _ -> error $ "hist_stream expects array, but got: " <> prettyString (show dest, show vs) def "acc_write" = Just $ - fun3t $ \acc i v -> + fun3 $ \acc i v -> case (acc, i) of ( ValueAcc op acc_arr, ValuePrim (SignedValue (Int64Value i')) @@ -1630,7 +1634,7 @@ initialCtx = _ -> error $ "acc_write invalid arguments: " <> prettyString (show acc, show i, show v) -- - def "flat_index_2d" = Just . fun6t $ \arr offset n1 s1 n2 s2 -> do + def "flat_index_2d" = Just . fun6 $ \arr offset n1 s1 n2 s2 -> do let offset' = asInt64 offset n1' = asInt64 n1 n2' = asInt64 n2 @@ -1649,7 +1653,7 @@ initialCtx = bad mempty mempty $ "Index out of bounds: " <> prettyText [((n1', s1'), (n2', s2'))] -- - def "flat_update_2d" = Just . fun5t $ \arr offset s1 s2 v -> do + def "flat_update_2d" = Just . fun5 $ \arr offset s1 s2 v -> do let offset' = asInt64 offset s1' = asInt64 s1 s2' = asInt64 s2 @@ -1666,7 +1670,7 @@ initialCtx = "Index out of bounds: " <> prettyText [((n1, s1'), (n2, s2'))] s -> error $ "flat_update_2d: invalid arg shape: " ++ show s -- - def "flat_index_3d" = Just . fun8t $ \arr offset n1 s1 n2 s2 n3 s3 -> do + def "flat_index_3d" = Just . fun8 $ \arr offset n1 s1 n2 s2 n3 s3 -> do let offset' = asInt64 offset n1' = asInt64 n1 n2' = asInt64 n2 @@ -1688,7 +1692,7 @@ initialCtx = bad mempty mempty $ "Index out of bounds: " <> prettyText [((n1', s1'), (n2', s2'), (n3', s3'))] -- - def "flat_update_3d" = Just . fun6t $ \arr offset s1 s2 s3 v -> do + def "flat_update_3d" = Just . fun6 $ \arr offset s1 s2 s3 v -> do let offset' = asInt64 offset s1' = asInt64 s1 s2' = asInt64 s2 @@ -1706,7 +1710,7 @@ initialCtx = "Index out of bounds: " <> prettyText [((n1, s1'), (n2, s2'), (n3, s3'))] s -> error $ "flat_update_3d: invalid arg shape: " ++ show s -- - def "flat_index_4d" = Just . fun10t $ \arr offset n1 s1 n2 s2 n3 s3 n4 s4 -> do + def "flat_index_4d" = Just . fun10 $ \arr offset n1 s1 n2 s2 n3 s3 n4 s4 -> do let offset' = asInt64 offset n1' = asInt64 n1 n2' = asInt64 n2 @@ -1731,7 +1735,7 @@ initialCtx = bad mempty mempty $ "Index out of bounds: " <> prettyText [(((n1', s1'), (n2', s2')), ((n3', s3'), (n4', s4')))] -- - def "flat_update_4d" = Just . fun7t $ \arr offset s1 s2 s3 s4 v -> do + def "flat_update_4d" = Just . fun7 $ \arr offset s1 s2 s3 s4 v -> do let offset' = asInt64 offset s1' = asInt64 s1 s2' = asInt64 s2 @@ -1762,7 +1766,7 @@ initialCtx = fromPair (Just [x, y]) = (x, y) fromPair _ = error "Not a pair" def "zip" = Just $ - fun2t $ \xs ys -> do + fun2 $ \xs ys -> do let ShapeDim _ xs_rowshape = valueShape xs ShapeDim _ ys_rowshape = valueShape ys pure $ @@ -1770,7 +1774,7 @@ initialCtx = map toTuple $ transpose [snd $ fromArray xs, snd $ fromArray ys] def "concat" = Just $ - fun2t $ \xs ys -> do + fun2 $ \xs ys -> do let (ShapeDim _ rowshape, xs') = fromArray xs (_, ys') = fromArray ys pure $ toArray' rowshape $ xs' ++ ys' @@ -1784,7 +1788,7 @@ initialCtx = genericTake m $ transpose (map (snd . fromArray) xs') ++ repeat [] def "rotate" = Just $ - fun2t $ \i xs -> do + fun2 $ \i xs -> do let (shape, xs') = fromArray xs pure $ let idx = if null xs' then 0 else rem (asInt i) (length xs') @@ -1800,7 +1804,7 @@ initialCtx = let (ShapeDim n (ShapeDim m shape), xs') = fromArray xs pure $ toArray (ShapeDim (n * m) shape) $ concatMap (snd . fromArray) xs' def "unflatten" = Just $ - fun3t $ \n m xs -> do + fun3 $ \n m xs -> do let (ShapeDim xs_size innershape, xs') = fromArray xs rowshape = ShapeDim (asInt64 m) innershape shape = ShapeDim (asInt64 n) rowshape @@ -1816,10 +1820,10 @@ initialCtx = <> "]" else pure $ toArray shape $ map (toArray rowshape) $ chunk (asInt m) xs' def "vjp2" = Just $ - fun3t $ + fun3 $ \_ _ _ -> bad noLoc mempty "Interpreter does not support autodiff." def "jvp2" = Just $ - fun3t $ + fun3 $ \_ _ _ -> bad noLoc mempty "Interpreter does not support autodiff." def "acc" = Nothing def s | nameFromString s `M.member` namesToPrimTypes = Nothing @@ -1878,7 +1882,7 @@ valueType v = checkEntryArgs :: VName -> [V.Value] -> StructType -> Either T.Text () checkEntryArgs entry args entry_t - | args_ts == param_ts = + | args_ts == map snd param_ts = pure () | otherwise = Left . docText $ @@ -1916,9 +1920,9 @@ interpretFunction ctx fname vs = do f <- evalTermVar (ctxEnv ctx) (qualName fname) ft foldM (apply noLoc mempty) f vs' where - updateType (vt : vts) (Scalar (Arrow als u pt (RetType dims rt))) = do + updateType (vt : vts) (Scalar (Arrow als pn d pt (RetType dims rt))) = do checkInput vt pt - Scalar . Arrow als u (valueStructType vt) . RetType dims <$> updateType vts rt + Scalar . Arrow als pn d (valueStructType vt) . RetType dims <$> updateType vts rt updateType _ t = Right t diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index 99be6a3840..210afc3fe6 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -115,6 +115,10 @@ prettyRetType _ (RetType dims t) = instance Pretty (Shape dim) => Pretty (RetTypeBase dim as) where pretty = prettyRetType 0 +instance Pretty Diet where + pretty Consume = "*" + pretty Observe = "" + prettyScalarType :: Pretty (Shape dim) => Int -> ScalarTypeBase dim as -> Doc a prettyScalarType _ (Prim et) = pretty et prettyScalarType p (TypeVar _ u v targs) = @@ -128,11 +132,16 @@ prettyScalarType _ (Record fs) where ppField (name, t) = pretty (nameToString name) <> colon <+> align (pretty t) fs' = map ppField $ M.toList fs -prettyScalarType p (Arrow _ (Named v) t1 t2) = +prettyScalarType p (Arrow _ (Named v) d t1 t2) = + parensIf (p > 1) $ + parens (prettyName v <> colon <+> pretty d <> align (pretty t1)) + <+> "->" + <+> prettyRetType 1 t2 +prettyScalarType p (Arrow _ Unnamed d t1 t2) = parensIf (p > 1) $ - parens (prettyName v <> colon <+> align (pretty t1)) <+> "->" <+> prettyRetType 1 t2 -prettyScalarType p (Arrow _ Unnamed t1 t2) = - parensIf (p > 1) $ prettyType 2 t1 <+> "->" <+> prettyRetType 1 t2 + (pretty d <> prettyType 2 t1) + <+> "->" + <+> prettyRetType 1 t2 prettyScalarType p (Sum cs) = parensIf (p > 0) $ group (align (mconcat $ punctuate (" |" <> line) cs')) @@ -353,7 +362,7 @@ prettyExp p (Lambda params body rettype _ _) = parensIf (p /= -1) $ "\\" <> hsep (map pretty params) <> ppAscription rettype <+> "->" - indent 2 (pretty body) + indent 2 (align (pretty body)) prettyExp _ (OpSection binop _ _) = parens $ pretty binop prettyExp _ (OpSectionLeft binop _ x _ _ _) = diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 73a1c26ec4..de6d7be73b 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -48,6 +48,7 @@ module Language.Futhark.Prop orderZero, unfoldFunType, foldFunType, + foldFunTypeFromParams, typeVars, -- * Operations on types @@ -174,8 +175,8 @@ traverseDims f = go mempty PosImmediate Scalar . Sum <$> traverse (traverse (go bound b)) cs go _ _ (Scalar (Prim t)) = pure $ Scalar $ Prim t - go bound _ (Scalar (Arrow als p t1 (RetType dims t2))) = - Scalar <$> (Arrow als p <$> go bound' PosParam t1 <*> (RetType dims <$> go bound' PosReturn t2)) + go bound _ (Scalar (Arrow als p u t1 (RetType dims t2))) = + Scalar <$> (Arrow als p u <$> go bound' PosParam t1 <*> (RetType dims <$> go bound' PosReturn t2)) where bound' = S.fromList dims @@ -208,16 +209,16 @@ aliases :: Monoid as => TypeBase shape as -> as aliases = bifoldMap (const mempty) id -- | @diet t@ returns a description of how a function parameter of --- type @t@ might consume its argument. +-- type @t@ consumes its argument. diet :: TypeBase shape as -> Diet -diet (Scalar (Record ets)) = RecordDiet $ fmap diet ets +diet (Scalar (Record ets)) = foldl max Observe $ fmap diet ets diet (Scalar (Prim _)) = Observe -diet (Scalar (Arrow _ _ t1 (RetType _ t2))) = FuncDiet (diet t1) (diet t2) +diet (Scalar (Arrow {})) = Observe diet (Array _ Unique _ _) = Consume diet (Array _ Nonunique _ _) = Observe diet (Scalar (TypeVar _ Unique _ _)) = Consume diet (Scalar (TypeVar _ Nonunique _ _)) = Observe -diet (Scalar (Sum cs)) = SumDiet $ M.map (map diet) cs +diet (Scalar (Sum cs)) = foldl max Observe $ foldMap (map diet) cs -- | Convert any type to one that has rank information, no alias -- information, and no embedded names. @@ -334,8 +335,14 @@ combineTypeShapes (Scalar (Sum cs1)) (Scalar (Sum cs2)) M.map (uncurry $ zipWith combineTypeShapes) (M.intersectionWith (,) cs1 cs2) -combineTypeShapes (Scalar (Arrow als1 p1 a1 (RetType dims1 b1))) (Scalar (Arrow als2 _p2 a2 (RetType _ b2))) = - Scalar $ Arrow (als1 <> als2) p1 (combineTypeShapes a1 a2) (RetType dims1 (combineTypeShapes b1 b2)) +combineTypeShapes (Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1))) (Scalar (Arrow als2 _p2 _d2 a2 (RetType _ b2))) = + Scalar $ + Arrow + (als1 <> als2) + p1 + d1 + (combineTypeShapes a1 a2) + (RetType dims1 (combineTypeShapes b1 b2)) combineTypeShapes (Scalar (TypeVar als1 u1 v targs1)) (Scalar (TypeVar als2 _ _ targs2)) = Scalar $ TypeVar (als1 <> als2) u1 v $ zipWith f targs1 targs2 where @@ -385,12 +392,12 @@ matchDims onDims = matchDims' mempty <$> traverse (traverse (uncurry (matchDims' bound))) (M.intersectionWith zip cs1 cs2) - ( Scalar (Arrow als1 p1 a1 (RetType dims1 b1)), - Scalar (Arrow als2 p2 a2 (RetType dims2 b2)) + ( Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1)), + Scalar (Arrow als2 p2 _d2 a2 (RetType dims2 b2)) ) -> let bound' = mapMaybe paramName [p1, p2] <> dims1 <> dims2 <> bound in Scalar - <$> ( Arrow (als1 <> als2) p1 + <$> ( Arrow (als1 <> als2) p1 d1 <$> matchDims' bound' a1 a2 <*> (RetType dims1 <$> matchDims' bound' b1 b2) ) @@ -488,45 +495,62 @@ typeOf (Update e _ _ _) = typeOf e `setAliases` mempty typeOf (RecordUpdate _ _ _ (Info t) _) = t typeOf (Assert _ e _ _) = typeOf e typeOf (Lambda params _ _ (Info (als, t)) _) = - let RetType [] t' = foldr (arrow . patternParam) t params - in t' `setAliases` als - where - arrow (Named v, x) (RetType dims y) = - RetType [] $ Scalar $ Arrow () (Named v) x $ RetType (v : dims) y - arrow (pn, tx) y = - RetType [] $ Scalar $ Arrow () pn tx y + funType params t `setAliases` als typeOf (OpSection _ (Info t) _) = t typeOf (OpSectionLeft _ _ _ (_, Info (pn, pt2)) (Info ret, _) _) = - Scalar $ Arrow mempty pn pt2 ret + Scalar $ Arrow mempty pn Observe pt2 ret typeOf (OpSectionRight _ _ _ (Info (pn, pt1), _) (Info ret) _) = - Scalar $ Arrow mempty pn pt1 ret + Scalar $ Arrow mempty pn Observe pt1 ret typeOf (ProjectSection _ (Info t) _) = t typeOf (IndexSection _ (Info t) _) = t typeOf (Constr _ _ (Info t) _) = t typeOf (Attr _ e _) = typeOf e typeOf (AppExp _ (Info res)) = appResType res +-- | The type of a function with the given parameters and return type. +funType :: [PatBase Info VName] -> StructRetType -> StructType +funType params ret = + let RetType _ t = foldr (arrow . patternParam) ret params + in t + where + arrow (xp, d, xt) yt = + RetType [] $ Scalar $ Arrow () xp d xt' yt + where + xt' = xt `setUniqueness` Nonunique + -- | @foldFunType ts ret@ creates a function type ('Arrow') that takes -- @ts@ as parameters and returns @ret@. foldFunType :: Monoid as => - [TypeBase dim pas] -> + [(Diet, TypeBase dim pas)] -> RetTypeBase dim as -> TypeBase dim as foldFunType ps ret = let RetType _ t = foldr arrow ret ps in t where - arrow t1 t2 = - RetType [] $ Scalar $ Arrow mempty Unnamed (toStruct t1) t2 + arrow (d, t1) t2 = + RetType [] $ Scalar $ Arrow mempty Unnamed d t1' t2 + where + t1' = toStruct t1 `setUniqueness` Nonunique + +foldFunTypeFromParams :: + Monoid as => + [PatBase Info VName] -> + RetTypeBase Size as -> + TypeBase Size as +foldFunTypeFromParams params = + foldFunType (zip (map diet params_ts) params_ts) + where + params_ts = map patternStructType params -- | Extract the parameter types and return type from a type. -- If the type is not an arrow type, the list of parameter types is empty. -unfoldFunType :: TypeBase dim as -> ([TypeBase dim ()], TypeBase dim ()) -unfoldFunType (Scalar (Arrow _ _ t1 (RetType _ t2))) = +unfoldFunType :: TypeBase dim as -> ([(Diet, TypeBase dim ())], TypeBase dim ()) +unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) = let (ps, r) = unfoldFunType t2 - in (t1 : ps, r) + in ((d, t1) : ps, r) unfoldFunType t = ([], toStruct t) -- | The type scheme of a value binding, comprising the type @@ -547,14 +571,6 @@ valBindBound vb = [] -> retDims (unInfo (valBindRetType vb)) _ -> [] --- | The type of a function with the given parameters and return type. -funType :: [PatBase Info VName] -> StructRetType -> StructType -funType params ret = - let RetType _ t = foldr (arrow . patternParam) ret params - in t - where - arrow (xp, xt) yt = RetType [] $ Scalar $ Arrow () xp xt yt - -- | The type names mentioned in a type. typeVars :: Monoid as => TypeBase dim as -> S.Set VName typeVars t = @@ -562,7 +578,7 @@ typeVars t = Scalar Prim {} -> mempty Scalar (TypeVar _ _ tn targs) -> mconcat $ S.singleton (qualLeaf tn) : map typeArgFree targs - Scalar (Arrow _ _ t1 (RetType _ t2)) -> typeVars t1 <> typeVars t2 + Scalar (Arrow _ _ _ t1 (RetType _ t2)) -> typeVars t1 <> typeVars t2 Scalar (Record fields) -> foldMap typeVars fields Scalar (Sum cs) -> mconcat $ (foldMap . fmap) typeVars cs Array _ _ _ rt -> typeVars $ Scalar rt @@ -644,17 +660,19 @@ patternStructType = toStruct . patternType -- | When viewed as a function parameter, does this pattern correspond -- to a named parameter of some type? -patternParam :: PatBase Info VName -> (PName, StructType) +patternParam :: PatBase Info VName -> (PName, Diet, StructType) patternParam (PatParens p _) = patternParam p patternParam (PatAttr _ p _) = patternParam p patternParam (PatAscription (Id v (Info t) _) _ _) = - (Named v, toStruct t) + (Named v, diet t, toStruct t) patternParam (Id v (Info t) _) = - (Named v, toStruct t) + (Named v, diet t, toStruct t) patternParam p = - (Unnamed, patternStructType p) + (Unnamed, diet p_t, p_t) + where + p_t = patternStructType p -- | Names of primitive types to types. This is only valid if no -- shadowing is going on, but useful for tools. @@ -677,7 +695,7 @@ namesToPrimTypes = data Intrinsic = IntrinsicMonoFun [PrimType] PrimType | IntrinsicOverloadedFun [PrimType] [Maybe PrimType] (Maybe PrimType) - | IntrinsicPolyFun [TypeParamBase VName] [StructType] (RetTypeBase Size ()) + | IntrinsicPolyFun [TypeParamBase VName] [(Diet, StructType)] (RetTypeBase Size ()) | IntrinsicType Liftedness [TypeParamBase VName] StructType | IntrinsicEquality -- Special cased. @@ -732,16 +750,16 @@ intrinsics = ++ [ ( "flatten", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [Array () Nonunique (shape [n, m]) t_a] + [(Observe, Array () Nonunique (shape [n, m]) t_a)] $ RetType [k] $ Array () Nonunique (shape [k]) t_a ), ( "unflatten", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar $ Prim $ Signed Int64, - Scalar $ Prim $ Signed Int64, - Array () Nonunique (shape [n]) t_a + [ (Observe, Scalar $ Prim $ Signed Int64), + (Observe, Scalar $ Prim $ Signed Int64), + (Observe, Array () Nonunique (shape [n]) t_a) ] $ RetType [k, m] $ Array () Nonunique (shape [k, m]) t_a @@ -749,33 +767,37 @@ intrinsics = ( "concat", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [arr_a $ shape [n], arr_a $ shape [m]] + [ (Observe, array_a $ shape [n]), + (Observe, array_a $ shape [m]) + ] $ RetType [k] - $ uarr_a + $ uarray_a $ shape [k] ), ( "rotate", IntrinsicPolyFun [tp_a, sp_n] - [Scalar $ Prim $ Signed Int64, arr_a $ shape [n]] + [ (Observe, Scalar $ Prim $ Signed Int64), + (Observe, array_a $ shape [n]) + ] $ RetType [] - $ arr_a + $ array_a $ shape [n] ), ( "transpose", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [arr_a $ shape [n, m]] + [(Observe, array_a $ shape [n, m])] $ RetType [] - $ arr_a + $ array_a $ shape [m, n] ), ( "scatter", IntrinsicPolyFun [tp_a, sp_n, sp_l] - [ Array () Unique (shape [n]) t_a, - Array () Nonunique (shape [l]) (Prim $ Signed Int64), - Array () Nonunique (shape [l]) t_a + [ (Consume, Array () Unique (shape [n]) t_a), + (Observe, Array () Nonunique (shape [l]) (Prim $ Signed Int64)), + (Observe, Array () Nonunique (shape [l]) t_a) ] $ RetType [] $ Array () Unique (shape [n]) t_a @@ -783,98 +805,100 @@ intrinsics = ( "scatter_2d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_l] - [ uarr_a $ shape [n, m], - Array () Nonunique (shape [l]) (tupInt64 2), - Array () Nonunique (shape [l]) t_a + [ (Consume, uarray_a $ shape [n, m]), + (Observe, Array () Nonunique (shape [l]) (tupInt64 2)), + (Observe, Array () Nonunique (shape [l]) t_a) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n, m] ), ( "scatter_3d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_k, sp_l] - [ uarr_a $ shape [n, m, k], - Array () Nonunique (shape [l]) (tupInt64 3), - Array () Nonunique (shape [l]) t_a + [ (Consume, uarray_a $ shape [n, m, k]), + (Observe, Array () Nonunique (shape [l]) (tupInt64 3)), + (Observe, Array () Nonunique (shape [l]) t_a) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n, m, k] ), ( "zip", IntrinsicPolyFun [tp_a, tp_b, sp_n] - [arr_a (shape [n]), arr_b (shape [n])] + [ (Observe, array_a (shape [n])), + (Observe, array_b (shape [n])) + ] $ RetType [] - $ tuple_uarr (Scalar t_a) (Scalar t_b) + $ tuple_uarray (Scalar t_a) (Scalar t_b) $ shape [n] ), ( "unzip", IntrinsicPolyFun [tp_a, tp_b, sp_n] - [tuple_arr (Scalar t_a) (Scalar t_b) $ shape [n]] + [(Observe, tuple_arr (Scalar t_a) (Scalar t_b) $ shape [n])] $ RetType [] . Scalar . Record . M.fromList - $ zip tupleFieldNames [arr_a $ shape [n], arr_b $ shape [n]] + $ zip tupleFieldNames [array_a $ shape [n], array_b $ shape [n]] ), ( "hist_1d", IntrinsicPolyFun [tp_a, sp_n, sp_m] - [ Scalar $ Prim $ Signed Int64, - uarr_a $ shape [m], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Array () Nonunique (shape [n]) (tupInt64 1), - arr_a (shape [n]) + [ (Consume, Scalar $ Prim $ Signed Int64), + (Observe, uarray_a $ shape [m]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 1)), + (Observe, array_a (shape [n])) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [m] ), ( "hist_2d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_k] - [ Scalar $ Prim $ Signed Int64, - uarr_a $ shape [m, k], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Array () Nonunique (shape [n]) (tupInt64 2), - arr_a (shape [n]) + [ (Observe, Scalar $ Prim $ Signed Int64), + (Consume, uarray_a $ shape [m, k]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 2)), + (Observe, array_a (shape [n])) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [m, k] ), ( "hist_3d", IntrinsicPolyFun [tp_a, sp_n, sp_m, sp_k, sp_l] - [ Scalar $ Prim $ Signed Int64, - uarr_a $ shape [m, k, l], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Array () Nonunique (shape [n]) (tupInt64 3), - arr_a (shape [n]) + [ (Observe, Scalar $ Prim $ Signed Int64), + (Consume, uarray_a $ shape [m, k, l]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 3)), + (Observe, array_a (shape [n])) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [m, k, l] ), ( "map", IntrinsicPolyFun [tp_a, tp_b, sp_n] - [ Scalar t_a `arr` Scalar t_b, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, array_a $ shape [n]) ] $ RetType [] - $ uarr_b + $ uarray_b $ shape [n] ), ( "reduce", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) ] $ RetType [] $ Scalar t_a @@ -882,9 +906,9 @@ intrinsics = ( "reduce_comm", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) ] $ RetType [] $ Scalar t_a @@ -892,24 +916,24 @@ intrinsics = ( "scan", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - arr_a $ shape [n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ), ( "partition", IntrinsicPolyFun [tp_a, sp_n] - [ Scalar (Prim $ Signed Int32), - Scalar t_a `arr` Scalar (Prim $ Signed Int64), - arr_a $ shape [n] + [ (Observe, Scalar (Prim $ Signed Int32)), + (Observe, Scalar t_a `arr` Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [n]) ] ( RetType [m] . Scalar $ tupleRecord - [ uarr_a $ shape [k], + [ uarray_a $ shape [k], Array () Unique (shape [n]) (Prim $ Signed Int64) ] ) @@ -917,48 +941,52 @@ intrinsics = ( "acc_write", IntrinsicPolyFun [sp_k, tp_a] - [ Scalar $ accType arr_ka, - Scalar (Prim $ Signed Int64), - Scalar t_a + [ (Consume, Scalar $ accType array_ka), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar t_a) ] $ RetType [] $ Scalar - $ accType arr_ka + $ accType array_ka ), ( "scatter_stream", IntrinsicPolyFun [tp_a, tp_b, sp_k, sp_n] - [ uarr_ka, - Scalar (accType arr_ka) - `arr` ( Scalar t_b - `arr` Scalar (accType $ arr_a $ shape [k]) - ), - arr_b $ shape [n] + [ (Consume, uarray_ka), + ( Observe, + Scalar (accType array_ka) + `carr` ( Scalar t_b + `arr` Scalar (accType $ array_a $ shape [k]) + ) + ), + (Observe, array_b $ shape [n]) ] - $ RetType [] uarr_ka + $ RetType [] uarray_ka ), ( "hist_stream", IntrinsicPolyFun [tp_a, tp_b, sp_k, sp_n] - [ uarr_a $ shape [k], - Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a), - Scalar t_a, - Scalar (accType arr_ka) - `arr` ( Scalar t_b - `arr` Scalar (accType $ arr_a $ shape [k]) - ), - arr_b $ shape [n] + [ (Consume, uarray_a $ shape [k]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + ( Observe, + Scalar (accType array_ka) + `carr` ( Scalar t_b + `arr` Scalar (accType $ array_a $ shape [k]) + ) + ), + (Observe, array_b $ shape [n]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [k] ), ( "jvp2", IntrinsicPolyFun [tp_a, tp_b] - [ Scalar t_a `arr` Scalar t_b, - Scalar t_a, - Scalar t_a + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, Scalar t_a), + (Observe, Scalar t_a) ] $ RetType [] $ Scalar @@ -967,9 +995,9 @@ intrinsics = ( "vjp2", IntrinsicPolyFun [tp_a, tp_b] - [ Scalar t_a `arr` Scalar t_b, - Scalar t_a, - Scalar t_b + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, Scalar t_a), + (Observe, Scalar t_b) ] $ RetType [] $ Scalar @@ -981,91 +1009,91 @@ intrinsics = [ ( "flat_index_2d", IntrinsicPolyFun [tp_a, sp_n] - [ arr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64) + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) ] $ RetType [m, k] - $ arr_a + $ array_a $ shape [m, k] ), ( "flat_update_2d", IntrinsicPolyFun [tp_a, sp_n, sp_k, sp_l] - [ uarr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - arr_a $ shape [k, l] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ), ( "flat_index_3d", IntrinsicPolyFun [tp_a, sp_n] - [ arr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64) + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) ] $ RetType [m, k, l] - $ arr_a + $ array_a $ shape [m, k, l] ), ( "flat_update_3d", IntrinsicPolyFun [tp_a, sp_n, sp_k, sp_l, sp_p] - [ uarr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - arr_a $ shape [k, l, p] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l, p]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ), ( "flat_index_4d", IntrinsicPolyFun [tp_a, sp_n] - [ arr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64) + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) ] $ RetType [m, k, l, p] - $ arr_a + $ array_a $ shape [m, k, l, p] ), ( "flat_update_4d", IntrinsicPolyFun [tp_a, sp_n, sp_k, sp_l, sp_p, sp_q] - [ uarr_a $ shape [n], - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - Scalar (Prim $ Signed Int64), - arr_a $ shape [k, l, p, q] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l, p, q]) ] $ RetType [] - $ uarr_a + $ uarray_a $ shape [n] ) ] @@ -1073,13 +1101,13 @@ intrinsics = [a, b, n, m, k, l, p, q] = zipWith VName (map nameFromString ["a", "b", "n", "m", "k", "l", "p", "q"]) [0 ..] t_a = TypeVar () Nonunique (qualName a) [] - arr_a s = Array () Nonunique s t_a - uarr_a s = Array () Unique s t_a + array_a s = Array () Nonunique s t_a + uarray_a s = Array () Unique s t_a tp_a = TypeParamType Unlifted a mempty t_b = TypeVar () Nonunique (qualName b) [] - arr_b s = Array () Nonunique s t_b - uarr_b s = Array () Unique s t_b + array_b s = Array () Nonunique s t_b + uarray_b s = Array () Unique s t_b tp_b = TypeParamType Unlifted b mempty [sp_n, sp_m, sp_k, sp_l, sp_p, sp_q] = map (`TypeParamDim` mempty) [n, m, k, l, p, q] @@ -1092,12 +1120,13 @@ intrinsics = Nonunique s (Record (M.fromList $ zip tupleFieldNames [x, y])) - tuple_uarr x y s = tuple_arr x y s `setUniqueness` Unique + tuple_uarray x y s = tuple_arr x y s `setUniqueness` Unique - arr x y = Scalar $ Arrow mempty Unnamed x (RetType [] y) + arr x y = Scalar $ Arrow mempty Unnamed Observe x (RetType [] y) + carr x y = Scalar $ Arrow mempty Unnamed Consume x (RetType [] y) - arr_ka = Array () Nonunique (Shape [NamedSize $ qualName k]) t_a - uarr_ka = Array () Unique (Shape [NamedSize $ qualName k]) t_a + array_ka = Array () Nonunique (Shape [NamedSize $ qualName k]) t_a + uarray_ka = Array () Unique (Shape [NamedSize $ qualName k]) t_a accType t = TypeVar () Unique (qualName (fst intrinsicAcc)) [TypeArgType t mempty] diff --git a/src/Language/Futhark/Query.hs b/src/Language/Futhark/Query.hs index d703b758ad..06e0b02c8f 100644 --- a/src/Language/Futhark/Query.hs +++ b/src/Language/Futhark/Query.hs @@ -88,7 +88,7 @@ expDefs e = Lambda params _ _ _ _ -> mconcat (map patternDefs params) AppExp (LetFun name (tparams, params, _, Info ret, _) _ loc) _ -> - let name_t = foldFunType (map patternStructType params) ret + let name_t = foldFunType (map (undefined . patternStructType) params) ret in M.singleton name (DefBound $ BoundTerm name_t (locOf loc)) <> mconcat (map typeParamDefs tparams) <> mconcat (map patternDefs params) @@ -111,7 +111,7 @@ valBindDefs vbind = <> expDefs (valBindBody vbind) where vbind_t = - foldFunType (map patternStructType (valBindParams vbind)) $ + foldFunType (map (undefined . patternStructType) (valBindParams vbind)) $ unInfo $ valBindRetType vbind diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index fd060ac4b8..6ab3486366 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -300,7 +300,7 @@ data ScalarTypeBase dim as | Sum (M.Map Name [TypeBase dim as]) | -- | The aliasing corresponds to the lexical -- closure of the function. - Arrow as PName (TypeBase dim ()) (RetTypeBase dim as) + Arrow as PName Diet (TypeBase dim ()) (RetTypeBase dim as) deriving (Eq, Ord, Show) instance Bitraversable ScalarTypeBase where @@ -308,8 +308,8 @@ instance Bitraversable ScalarTypeBase where bitraverse f g (Record fs) = Record <$> traverse (bitraverse f g) fs bitraverse f g (TypeVar als u t args) = TypeVar <$> g als <*> pure u <*> pure t <*> traverse (traverse f) args - bitraverse f g (Arrow als v t1 t2) = - Arrow <$> g als <*> pure v <*> bitraverse f pure t1 <*> bitraverse f g t2 + bitraverse f g (Arrow als v d t1 t2) = + Arrow <$> g als <*> pure v <*> pure d <*> bitraverse f pure t1 <*> bitraverse f g t2 bitraverse f g (Sum cs) = Sum <$> (traverse . traverse) (bitraverse f g) cs instance Bifunctor ScalarTypeBase where @@ -456,21 +456,13 @@ instance Located (TypeArgExp vn) where locOf (TypeArgExpDim _ loc) = locOf loc locOf (TypeArgExpType t) = locOf t --- | Information about which parts of a value/type are consumed. +-- | Information about which parts of a parameter are consumed. This +-- can be considered kind of an effect on the function. data Diet - = -- | Consumes these fields in the record. - RecordDiet (M.Map Name Diet) - | -- | Consume these parts of the constructors. - SumDiet (M.Map Name [Diet]) - | -- | A function that consumes its argument(s) like this. - -- The final 'Diet' should always be 'Observe', as there - -- is no way for a function to consume its return value. - FuncDiet Diet Diet - | -- | Consumes this value. - Consume - | -- | Only observes value in this position, does - -- not consume. + = -- | Does not consume the parameter. Observe + | -- | Consumes the parameter. + Consume deriving (Eq, Ord, Show) -- | An identifier consists of its name and the type of the value diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index d113f2f571..1ee8f04032 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -301,10 +301,11 @@ traverseScalarType _ _ _ (Prim t) = pure $ Prim t traverseScalarType f g h (Record fs) = Record <$> traverse (traverseType f g h) fs traverseScalarType f g h (TypeVar als u t args) = TypeVar <$> h als <*> pure u <*> f t <*> traverse (traverseTypeArg f g) args -traverseScalarType f g h (Arrow als v t1 (RetType dims t2)) = +traverseScalarType f g h (Arrow als v u t1 (RetType dims t2)) = Arrow <$> h als <*> pure v + <*> pure u <*> traverseType f g pure t1 <*> (RetType dims <$> traverseType f g h t2) traverseScalarType f g h (Sum cs) = diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 7874be55f6..473281bb83 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -617,10 +617,10 @@ entryPoint params orig_ret_te (RetType ret orig_ret) = pname (Named v) = baseName v pname Unnamed = "_" - onRetType (Just (TEArrow p t1_te t2_te _)) (Scalar (Arrow _ _ t1 (RetType _ t2))) = + onRetType (Just (TEArrow p t1_te t2_te _)) (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = let (xs, y) = onRetType (Just t2_te) t2 in (EntryParam (maybe "_" baseName p) (EntryType t1 (Just t1_te)) : xs, y) - onRetType _ (Scalar (Arrow _ p t1 (RetType _ t2))) = + onRetType _ (Scalar (Arrow _ p _ t1 (RetType _ t2))) = let (xs, y) = onRetType Nothing t2 in (EntryParam (pname p) (EntryType t1 Nothing) : xs, y) onRetType te t = @@ -639,8 +639,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype withIndexLink "polymorphic-entry" "Entry point functions may not be polymorphic." - | not (all patternOrderZero params) - || not (all orderZero rettype_params) + | not (all orderZero param_ts) || not (orderZero rettype') = typeError loc mempty $ withIndexLink @@ -649,7 +648,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype | sizes_only_in_ret <- S.fromList (map typeParamName tparams) `S.intersection` freeInType rettype' - `S.difference` foldMap freeInType (map patternStructType params ++ rettype_params), + `S.difference` foldMap freeInType param_ts, not $ S.null sizes_only_in_ret = typeError loc mempty $ withIndexLink @@ -670,6 +669,7 @@ checkEntryPoint loc tparams params maybe_tdecl rettype where (RetType _ rettype_t) = rettype (rettype_params, rettype') = unfoldFunType rettype_t + param_ts = map patternStructType params ++ map snd rettype_params checkValBind :: ValBindBase NoInfo Name -> TypeM (Env, ValBind) checkValBind (ValBind entry fname maybe_tdecl NoInfo tparams params body doc attrs loc) = do @@ -705,9 +705,9 @@ nastyType t@Array {} = nastyType $ stripArray 1 t nastyType _ = True nastyReturnType :: Monoid als => Maybe (TypeExp VName) -> TypeBase dim als -> Bool -nastyReturnType Nothing (Scalar (Arrow _ _ t1 (RetType _ t2))) = +nastyReturnType Nothing (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = nastyType t1 || nastyReturnType Nothing t2 -nastyReturnType (Just (TEArrow _ te1 te2 _)) (Scalar (Arrow _ _ t1 (RetType _ t2))) = +nastyReturnType (Just (TEArrow _ te1 te2 _)) (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = (not (niceTypeExp te1) && nastyType t1) || nastyReturnType (Just te2) t2 nastyReturnType (Just te) _ diff --git a/src/Language/Futhark/TypeChecker/Modules.hs b/src/Language/Futhark/TypeChecker/Modules.hs index ea7cdc1065..b9a1d133af 100644 --- a/src/Language/Futhark/TypeChecker/Modules.hs +++ b/src/Language/Futhark/TypeChecker/Modules.hs @@ -149,8 +149,8 @@ newNamesForMTy orig_mty = do Scalar $ Sum $ (fmap . fmap) substituteInType ts substituteInType (Array () u shape t) = arrayOf u (substituteInShape shape) (substituteInType $ Scalar t) - substituteInType (Scalar (Arrow als v t1 (RetType dims t2))) = - Scalar $ Arrow als v (substituteInType t1) $ RetType dims $ substituteInType t2 + substituteInType (Scalar (Arrow als v d1 t1 (RetType dims t2))) = + Scalar $ Arrow als v d1 (substituteInType t1) $ RetType dims $ substituteInType t2 substituteInShape (Shape ds) = Shape $ map substituteInDim ds diff --git a/src/Language/Futhark/TypeChecker/Monad.hs b/src/Language/Futhark/TypeChecker/Monad.hs index c3d84c9896..a107a9eefb 100644 --- a/src/Language/Futhark/TypeChecker/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Monad.hs @@ -432,8 +432,8 @@ qualifyTypeVars outer_env orig_except ref_qs = onType (S.fromList orig_except) Record $ M.map (onType except) m onScalar except (Sum m) = Sum $ M.map (map $ onType except) m - onScalar except (Arrow as p t1 (RetType dims t2)) = - Arrow as p (onType except' t1) $ RetType dims (onType except' t2) + onScalar except (Arrow as p d t1 (RetType dims t2)) = + Arrow as p d (onType except' t1) $ RetType dims (onType except' t2) where except' = case p of Named p' -> S.insert p' except diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index da859e46d6..44cc4d25fe 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -211,10 +211,10 @@ checkApplyExp (AppExp (Apply e1 e2 _ loc) _) = do arg <- checkArg e2 (e1', (fname, i)) <- checkApplyExp e1 t <- expType e1' - (t1, rt, argext, exts) <- checkApply loc (fname, i) t arg + (d1, _, rt, argext, exts) <- checkApply loc (fname, i) t arg pure ( AppExp - (Apply e1' (argExp arg) (Info (diet t1, argext)) loc) + (Apply e1' (argExp arg) (Info (d1, argext)) loc) (Info $ AppRes rt exts), (fname, i + 1) ) @@ -347,8 +347,8 @@ checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. - (p1_t, rt, p1_ext, _) <- checkApply loc (Just op', 0) ftype e1_arg - (p2_t, rt', p2_ext, retext) <- checkApply loc (Just op', 1) rt e2_arg + (_, p1_t, rt, p1_ext, _) <- checkApply loc (Just op', 0) ftype e1_arg + (_, p2_t, rt', p2_ext, retext) <- checkApply loc (Just op', 1) rt e2_arg pure $ AppExp @@ -471,8 +471,7 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l bindSpaced [(Term, name)] $ do name' <- checkName Term name loc - let arrow (xp, xt) yt = RetType [] $ Scalar $ Arrow () xp xt yt - RetType _ ftype = foldr (arrow . patternParam) rettype params' + let ftype = funType params' rettype entry = BoundV Local tparams' $ ftype `setAliases` closure' bindF scope = scope @@ -627,12 +626,11 @@ checkExp (Lambda params body rettype_te NoInfo loc) = do -- are parameters, or are used in parameters. inferReturnSizes params' ret = do cur_lvl <- curLevel - let named (Named x, _) = Just x - named (Unnamed, _) = Nothing + let named (Named x, _, _) = Just x + named (Unnamed, _, _) = Nothing param_names = mapMaybe (named . patternParam) params' pos_sizes = - sizeNamesPos . foldFunType (map patternStructType params') $ - RetType [] ret + sizeNamesPos $ foldFunTypeFromParams params' $ RetType [] ret hide k (lvl, _) = lvl >= cur_lvl && k `notElem` param_names && k `S.notMember` pos_sizes @@ -650,9 +648,9 @@ checkExp (OpSection op _ loc) = do checkExp (OpSectionLeft op _ e _ _ loc) = do (op', ftype) <- lookupVar loc op e_arg <- checkArg e - (t1, rt, argext, retext) <- checkApply loc (Just op', 0) ftype e_arg + (_, t1, rt, argext, retext) <- checkApply loc (Just op', 0) ftype e_arg case (ftype, rt) of - (Scalar (Arrow _ m1 _ _), Scalar (Arrow _ m2 t2 rettype)) -> + (Scalar (Arrow _ m1 _ _ _), Scalar (Arrow _ m2 _ t2 rettype)) -> pure $ OpSectionLeft op' @@ -668,12 +666,12 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do (op', ftype) <- lookupVar loc op e_arg <- checkArg e case ftype of - Scalar (Arrow as1 m1 t1 (RetType [] (Scalar (Arrow as2 m2 t2 (RetType dims2 ret))))) -> do - (t2', ret', argext, _) <- + Scalar (Arrow as1 m1 d1 t1 (RetType [] (Scalar (Arrow as2 m2 d2 t2 (RetType dims2 ret))))) -> do + (_, t2', ret', argext, _) <- checkApply loc (Just op', 1) - (Scalar $ Arrow as2 m2 t2 $ RetType [] $ Scalar $ Arrow as1 m1 t1 $ RetType [] ret) + (Scalar $ Arrow as2 m2 d2 t2 $ RetType [] $ Scalar $ Arrow as1 m1 d1 t1 $ RetType [] ret) e_arg pure $ OpSectionRight @@ -690,13 +688,13 @@ checkExp (ProjectSection fields NoInfo loc) = do a <- newTypeVar loc "a" let usage = mkUsage loc "projection at" b <- foldM (flip $ mustHaveField usage) a fields - let ft = Scalar $ Arrow mempty Unnamed (toStruct a) $ RetType [] b + let ft = Scalar $ Arrow mempty Unnamed Observe (toStruct a) $ RetType [] b pure $ ProjectSection fields (Info ft) loc checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice (t, _) <- newArrayType loc "e" $ sliceDims slice' (t', retext) <- sliceShape Nothing slice' t - let ft = Scalar $ Arrow mempty Unnamed t $ RetType retext $ fromStruct t' + let ft = Scalar $ Arrow mempty Unnamed Observe t $ RetType retext $ fromStruct t' pure $ IndexSection slice' (Info ft) loc checkExp (AppExp (DoLoop _ mergepat mergeexp form loopbody loc) _) = do ((sparams, mergepat', mergeexp', form', loopbody'), appres) <- @@ -841,7 +839,7 @@ boundInsideType (Scalar (TypeVar _ _ _ targs)) = foldMap f targs f TypeArgDim {} = mempty boundInsideType (Scalar (Record fs)) = foldMap boundInsideType fs boundInsideType (Scalar (Sum cs)) = foldMap (foldMap boundInsideType) cs -boundInsideType (Scalar (Arrow _ pn t1 (RetType dims t2))) = +boundInsideType (Scalar (Arrow _ pn _ t1 (RetType dims t2))) = pn' <> boundInsideType t1 <> S.fromList dims <> boundInsideType t2 where pn' = case pn of @@ -863,11 +861,11 @@ checkApply :: ApplyOp -> PatType -> Arg -> - TermTypeM (StructType, PatType, Maybe VName, [VName]) + TermTypeM (Diet, StructType, PatType, Maybe VName, [VName]) checkApply loc (fname, _) - (Scalar (Arrow as pname tp1 tp2)) + (Scalar (Arrow as pname d1 tp1 tp2)) (argexp, argtype, dflow, argloc) = onFailure (CheckingApply fname argexp tp1 (toStruct argtype)) $ do expect (mkUsage argloc "use as function argument") (toStruct tp1) (toStruct argtype) @@ -896,14 +894,14 @@ checkApply in zeroOrderType (mkUsage argloc "potential consumption in expression") msg tp1 _ -> pure () - arg_consumed <- consumedByArg argloc argtype' (diet tp1') + arg_consumed <- consumedByArg (locOf argloc) argtype' d1 checkIfConsumable loc $ mconcat arg_consumed occur $ dflow `seqOccurrences` map (`consumption` argloc) arg_consumed -- Unification ignores uniqueness in higher-order arguments, so -- we check for that here. unless (toStructural argtype' `subtypeOf` setUniqueness (toStructural tp1') Nonunique) $ - typeError loc mempty "Consumption/aliasing does not match." + typeError loc mempty "Difference in whether argument is consumed." (argext, parsubst) <- case pname of @@ -923,18 +921,16 @@ checkApply v <- newID "internal_app_result" modify $ \s -> s {stateNames = M.insert v (NameAppRes fname loc) $ stateNames s} let appres = S.singleton $ AliasFree v - let tp2'' = applySubst parsubst $ returnType appres tp2' (diet tp1') argtype' + let tp2'' = applySubst parsubst $ returnType appres tp2' d1 argtype' - pure (tp1', tp2'', argext, ext) + pure (d1, tp1', tp2'', argext, ext) checkApply loc fname tfun@(Scalar TypeVar {}) arg = do tv <- newTypeVar loc "b" -- Change the uniqueness of the argument type because we never want -- to infer that a function is consuming. let argt_nonunique = toStruct (argType arg) `setUniqueness` Nonunique unify (mkUsage loc "use as function") (toStruct tfun) $ - Scalar $ - Arrow mempty Unnamed argt_nonunique $ - RetType [] tv + Scalar (Arrow mempty Unnamed Observe argt_nonunique $ RetType [] tv) tfun' <- normPatType tfun checkApply loc fname tfun' arg checkApply loc (fname, prev_applied) ftype (argexp, _, _, _) = do @@ -962,27 +958,21 @@ checkApply loc (fname, prev_applied) ftype (argexp, _, _, _) = do | prev_applied == 1 = "argument" | otherwise = "arguments" -consumedByArg :: SrcLoc -> PatType -> Diet -> TermTypeM [Aliasing] -consumedByArg loc (Scalar (Record ets)) (RecordDiet ds) = - mconcat . M.elems <$> traverse (uncurry $ consumedByArg loc) (M.intersectionWith (,) ets ds) -consumedByArg loc (Scalar (Sum ets)) (SumDiet ds) = - mconcat <$> traverse (uncurry $ consumedByArg loc) (concat $ M.elems $ M.intersectionWith zip ets ds) -consumedByArg loc (Scalar (Arrow _ _ t1 _)) (FuncDiet d _) - | not $ contravariantArg t1 d = - typeError loc mempty . withIndexLink "consuming-argument" $ - "Non-consuming higher-order parameter passed consuming argument." +aliasParts :: PatType -> [Aliasing] +aliasParts (Scalar (Record ts)) = foldMap aliasParts $ M.elems ts +aliasParts t = [aliases t] + +consumedByArg :: Loc -> PatType -> Diet -> TermTypeM [Aliasing] +consumedByArg loc at Consume = do + let parts = aliasParts at + foldM_ check mempty parts + pure parts where - contravariantArg (Array _ Unique _ _) Observe = - False - contravariantArg (Scalar (TypeVar _ Unique _ _)) Observe = - False - contravariantArg (Scalar (Record ets)) (RecordDiet ds) = - and (M.intersectionWith contravariantArg ets ds) - contravariantArg (Scalar (Arrow _ _ tp (RetType _ tr))) (FuncDiet dp dr) = - contravariantArg tp dp && contravariantArg tr dr - contravariantArg _ _ = - True -consumedByArg _ at Consume = pure [aliases at] + check seen als + | any (`S.member` seen) als = + typeError loc mempty . withIndexLink "self-aliasing-arg" $ + "Argument passed for consuming parameter is self-aliased." + | otherwise = pure $ als <> seen consumedByArg _ _ _ = pure [] -- | Type-check a single expression in isolation. This expression may @@ -1283,8 +1273,8 @@ hiddenParamNames :: [Pat] -> Names hiddenParamNames params = hidden where param_all_names = mconcat $ map patNames params - named (Named x, _) = Just x - named (Unnamed, _) = Nothing + named (Named x, _, _) = Just x + named (Unnamed, _, _) = Nothing param_names = S.fromList $ mapMaybe (named . patternParam) params hidden = param_all_names `S.difference` param_names @@ -1399,7 +1389,7 @@ checkBinding (fname, maybe_retdecl, tparams, params, body, loc) = -- | Extract all the shape names that occur in positive position -- (roughly, left side of an arrow) in a given type. sizeNamesPos :: TypeBase Size als -> S.Set VName -sizeNamesPos (Scalar (Arrow _ _ t1 (RetType _ t2))) = onParam t1 <> sizeNamesPos t2 +sizeNamesPos (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = onParam t1 <> sizeNamesPos t2 where onParam :: TypeBase Size als -> S.Set VName onParam (Scalar Arrow {}) = mempty @@ -1511,7 +1501,7 @@ verifyFunctionParams fname params = where forbidden' = case patternParam p of - (Named v, _) -> forbidden `S.difference` S.singleton v + (Named v, _, _) -> forbidden `S.difference` S.singleton v _ -> forbidden verifyParams _ [] = pure () @@ -1536,8 +1526,8 @@ injectExt ext ret = RetType ext_here $ deeper ret deeper (Scalar (Prim t)) = Scalar $ Prim t deeper (Scalar (Record fs)) = Scalar $ Record $ M.map deeper fs deeper (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map deeper) cs - deeper (Scalar (Arrow als p t1 (RetType t2_ext t2))) = - Scalar $ Arrow als p t1 $ injectExt (ext_there <> t2_ext) t2 + deeper (Scalar (Arrow als p d1 t1 (RetType t2_ext t2))) = + Scalar $ Arrow als p d1 t1 $ injectExt (ext_there <> t2_ext) t2 deeper (Scalar (TypeVar as u tn targs)) = Scalar $ TypeVar as u tn $ map deeperArg targs deeper t@Array {} = t @@ -1571,7 +1561,8 @@ closeOverTypes defname defloc tparams paramts ret substs = do injectExt (retext ++ mapMaybe mkExt (S.toList $ freeInType ret)) ret ) where - t = foldFunType paramts $ RetType [] ret + -- Diet does not matter here. + t = foldFunType (zip (repeat Observe) paramts) $ RetType [] ret to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs visible = typeVars t <> freeInType t diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 1411f30563..fd61dc452d 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -677,9 +677,9 @@ instance MonadTypeChecker TermTypeM where argtype <- newTypeVar loc "t" equalityType usage argtype pure $ - Scalar . Arrow mempty Unnamed argtype . RetType [] $ + Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ Scalar $ - Arrow mempty Unnamed argtype $ + Arrow mempty Unnamed Observe argtype $ RetType [] $ Scalar $ Prim Bool @@ -687,7 +687,7 @@ instance MonadTypeChecker TermTypeM where argtype <- newTypeVar loc "t" mustBeOneOf ts usage argtype let (pts', rt') = instOverloaded argtype pts rt - arrow xt yt = Scalar $ Arrow mempty Unnamed xt $ RetType [] yt + arrow xt yt = Scalar $ Arrow mempty Unnamed Observe xt $ RetType [] yt pure $ fromStruct $ foldr arrow rt' pts' observe $ Ident name (Info t) loc @@ -1002,7 +1002,7 @@ initialTermScope = initialVtable = M.fromList $ mapMaybe addIntrinsicF $ M.toList intrinsics prim = Scalar . Prim - arrow x y = Scalar $ Arrow mempty Unnamed x y + arrow x y = Scalar $ Arrow mempty Unnamed Observe x y addIntrinsicF (name, IntrinsicMonoFun pts t) = Just (name, BoundV Global [] $ arrow pts' $ RetType [] $ prim t) @@ -1015,15 +1015,8 @@ initialTermScope = addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = Just ( name, - BoundV Global tvs $ - fromStruct $ - Scalar $ - Arrow mempty Unnamed pts' rt + BoundV Global tvs $ fromStruct $ foldFunType pts rt ) - where - pts' = case pts of - [pt] -> pt - _ -> Scalar $ tupleRecord pts addIntrinsicF (name, IntrinsicEquality) = Just (name, EqualityF) addIntrinsicF _ = Nothing diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 62f39b056f..e1ad04143c 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -72,7 +72,7 @@ mustBeExplicitInBinding bind_t = M.fromList $ zip (S.toList $ freeInType ret) $ repeat True - in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty ts + in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty $ map snd ts where onType uses t = uses <> mustBeExplicitAux t -- Left-biased union. @@ -102,8 +102,8 @@ returnType appres (Scalar (TypeVar als Nonunique t targs)) d arg = Scalar $ TypeVar (appres <> als <> arg_als) Unique t targs where arg_als = aliases $ maskAliases arg d -returnType _ (Scalar (Arrow old_als v t1 (RetType dims t2))) d arg = - Scalar $ Arrow als v (t1 `setAliases` mempty) $ RetType dims $ t2 `setAliases` als +returnType _ (Scalar (Arrow old_als v pd t1 (RetType dims t2))) d arg = + Scalar $ Arrow als v pd (t1 `setAliases` mempty) $ RetType dims $ t2 `setAliases` als where -- Make sure to propagate the aliases of an existing closure. als = old_als <> aliases (maskAliases arg d) @@ -119,12 +119,6 @@ maskAliases :: TypeBase shape as maskAliases t Consume = t `setAliases` mempty maskAliases t Observe = t -maskAliases (Scalar (Record ets)) (RecordDiet ds) = - Scalar $ Record $ M.intersectionWith maskAliases ets ds -maskAliases (Scalar (Sum ets)) (SumDiet ds) = - Scalar $ Sum $ M.intersectionWith (zipWith maskAliases) ets ds -maskAliases t FuncDiet {} = t -maskAliases _ _ = error "Invalid arguments passed to maskAliases." -- | The two types are assumed to be structurally equal, but not -- necessarily regarding sizes. Combines aliases. @@ -140,9 +134,9 @@ addAliasesFromType (Scalar (Record ts1)) (Scalar (Record ts2)) sort (M.keys ts1) == sort (M.keys ts2) = Scalar $ Record $ M.intersectionWith addAliasesFromType ts1 ts2 addAliasesFromType - (Scalar (Arrow als1 mn1 pt1 (RetType dims1 rt1))) - (Scalar (Arrow als2 _ _ (RetType _ rt2))) = - Scalar (Arrow (als1 <> als2) mn1 pt1 (RetType dims1 rt1')) + (Scalar (Arrow als1 mn1 d1 pt1 (RetType dims1 rt1))) + (Scalar (Arrow als2 _ _ _ (RetType _ rt2))) = + Scalar (Arrow (als1 <> als2) mn1 d1 pt1 (RetType dims1 rt1')) where rt1' = addAliasesFromType rt1 rt2 addAliasesFromType (Scalar (Sum cs1)) (Scalar (Sum cs2)) @@ -203,9 +197,9 @@ unifyScalarTypes uf (Record ts1) (Record ts2) (M.intersectionWith (,) ts1 ts2) unifyScalarTypes uf - (Arrow as1 mn1 t1 (RetType dims1 t1')) - (Arrow as2 _ t2 (RetType _ t2')) = - Arrow (as1 <> as2) mn1 + (Arrow as1 mn1 d1 t1 (RetType dims1 t1')) + (Arrow as2 _ _ t2 (RetType _ t2')) = + Arrow (as1 <> as2) mn1 d1 <$> unifyTypesU (flip uf) t1 t2 <*> (RetType dims1 <$> unifyTypesU uf t1' t2') unifyScalarTypes uf (Sum cs1) (Sum cs2) @@ -339,7 +333,7 @@ evalTypeExp (TEArrow (Just v) t1 t2 loc) = do pure ( TEArrow (Just v') t1' t2' loc, svars1 ++ dims1 ++ svars2, - RetType [] $ Scalar $ Arrow mempty (Named v') st1 (RetType dims2 st2), + RetType [] $ Scalar $ Arrow mempty (Named v') (diet st1) st1 (RetType dims2 st2), Lifted ) -- @@ -349,7 +343,9 @@ evalTypeExp (TEArrow Nothing t1 t2 loc) = do pure ( TEArrow Nothing t1' t2' loc, svars1 ++ dims1 ++ svars2, - RetType [] $ Scalar $ Arrow mempty Unnamed st1 $ RetType dims2 st2, + RetType [] . Scalar $ + Arrow mempty Unnamed (diet st1) (st1 `setUniqueness` Nonunique) $ + RetType dims2 st2, Lifted ) -- @@ -741,8 +737,8 @@ substTypesRet lookupSubst ot = pure $ Scalar $ TypeVar als u v targs' onType (Scalar (Record ts)) = Scalar . Record <$> traverse onType ts - onType (Scalar (Arrow als v t1 t2)) = - Scalar <$> (Arrow als v <$> onType t1 <*> onRetType t2) + onType (Scalar (Arrow als v d t1 t2)) = + Scalar <$> (Arrow als v d <$> onType t1 <*> onRetType t2) onType (Scalar (Sum ts)) = Scalar . Sum <$> traverse (traverse onType) ts diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 7f87740a54..fc6a3c1445 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -482,15 +482,15 @@ unifyWith onDims usage = subunify False (_, Scalar (TypeVar _ _ (QualName [] v2) [])) | Just lvl <- nonrigid v2 -> link (not ord) v2 lvl t1' - ( Scalar (Arrow _ p1 a1 (RetType b1_dims b1)), - Scalar (Arrow _ p2 a2 (RetType b2_dims b2)) + ( Scalar (Arrow _ p1 d1 a1 (RetType b1_dims b1)), + Scalar (Arrow _ p2 d2 a2 (RetType b2_dims b2)) ) - | uncurry (<) $ swap ord (uniqueness a1) (uniqueness a2) -> do + | uncurry (<) $ swap ord d1 d2 -> do unifyError usage mempty bcs . withIndexLink "unify-consuming-param" $ - "Parameter types" - indent 2 (pretty a1) + "Parameters" + indent 2 (pretty d1 <> pretty a1) "and" - indent 2 (pretty a2) + indent 2 (pretty d2 <> pretty a2) "are incompatible regarding consuming their arguments." | otherwise -> do -- Introduce the existentials as size variables so they diff --git a/tests/accs/intrinsics.fut b/tests/accs/intrinsics.fut index 5b893d5cc1..4f6af17c5f 100644 --- a/tests/accs/intrinsics.fut +++ b/tests/accs/intrinsics.fut @@ -8,7 +8,7 @@ def scatter_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.scatter_stream (dest, f, bs) :> *[k]a + intrinsics.scatter_stream dest f bs :> *[k]a def reduce_by_index_stream [k] 'a 'b (dest: *[k]a) @@ -17,7 +17,7 @@ def reduce_by_index_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.hist_stream (dest, op, ne, f, bs) :> *[k]a + intrinsics.hist_stream dest op ne f bs :> *[k]a def write [n] 't (acc : *acc ([n]t)) (i: i64) (v: t) : *acc ([n]t) = - intrinsics.acc_write (acc, i, v) + intrinsics.acc_write acc i v diff --git a/tests/implicit_method.fut b/tests/implicit_method.fut index cac918909b..bf05be0f65 100644 --- a/tests/implicit_method.fut +++ b/tests/implicit_method.fut @@ -44,7 +44,7 @@ -- } -def tridagSeq [n][m] (a: [n]f32,b: *[m]f32,c: [m]f32,y: *[m]f32 ): *[m]f32 = +def tridagSeq [n][m] (a: [n]f32) (b: *[m]f32) (c: [m]f32) (y: *[m]f32 ): *[m]f32 = let (y,b) = loop ((y, b)) for i < n-1 do let i = i + 1 @@ -71,7 +71,8 @@ def implicitMethod [n][m] (myD: [m][3]f32, myDD: [m][3]f32, , dtInv - 0.5*(mu*d[1] + 0.5*var*dd[1]) , 0.0 - 0.5*(mu*d[2] + 0.5*var*dd[2]))) (zip4 (mu_row) (var_row) myD myDD)) - in tridagSeq( a, copy b, c, copy u_row )) (zip3 myMu myVar u) + in tridagSeq a (copy b) c (copy u_row)) + (zip3 myMu myVar u) def main [m][n] (myD: [m][3]f32) (myDD: [m][3]f32) (myMu: [n][m]f32) (myVar: [n][m]f32) diff --git a/tests/issue-1774.fut b/tests/issue1774.fut similarity index 100% rename from tests/issue-1774.fut rename to tests/issue1774.fut diff --git a/tests/migration/intrinsics.fut b/tests/migration/intrinsics.fut index f6ad2c3b98..f10742dcb0 100644 --- a/tests/migration/intrinsics.fut +++ b/tests/migration/intrinsics.fut @@ -1,20 +1,20 @@ def flat_index_2d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) : [n1][n2]a = - intrinsics.flat_index_2d(as, offset, n1, s1, n2, s2) :> [n1][n2]a + intrinsics.flat_index_2d as offset n1 s1 n2 s2 :> [n1][n2]a def flat_update_2d [n][k][l] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (asss: [k][l]a) : *[n]a = - intrinsics.flat_update_2d(as, offset, s1, s2, asss) + intrinsics.flat_update_2d as offset s1 s2 asss def flat_index_3d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) : [n1][n2][n3]a = - intrinsics.flat_index_3d(as, offset, n1, s1, n2, s2, n3, s3) :> [n1][n2][n3]a + intrinsics.flat_index_3d as offset n1 s1 n2 s2 n3 s3 :> [n1][n2][n3]a def flat_update_3d [n][k][l][p] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (asss: [k][l][p]a) : *[n]a = - intrinsics.flat_update_3d(as, offset, s1, s2, s3, asss) + intrinsics.flat_update_3d as offset s1 s2 s3 asss def flat_index_4d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) (n4: i64) (s4: i64) : [n1][n2][n3][n4]a = - intrinsics.flat_index_4d(as, offset, n1, s1, n2, s2, n3, s3, n4, s4) :> [n1][n2][n3][n4]a + intrinsics.flat_index_4d as offset n1 s1 n2 s2 n3 s3 n4 s4 :> [n1][n2][n3][n4]a def flat_update_4d [n][k][l][p][q] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (s4: i64) (asss: [k][l][p][q]a) : *[n]a = - intrinsics.flat_update_4d(as, offset, s1, s2, s3, s4, asss) + intrinsics.flat_update_4d as offset s1 s2 s3 s4 asss type~ acc 't = intrinsics.acc t @@ -23,7 +23,7 @@ def scatter_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.scatter_stream (dest, f, bs) :> *[k]a + intrinsics.scatter_stream dest f bs :> *[k]a def reduce_by_index_stream [k] 'a 'b (dest: *[k]a) @@ -32,7 +32,7 @@ def reduce_by_index_stream [k] 'a 'b (f: *acc ([k]a) -> b -> *acc ([k]a)) (bs: []b) : *[k]a = - intrinsics.hist_stream (dest, op, ne, f, bs) :> *[k]a + intrinsics.hist_stream dest op ne f bs :> *[k]a def write [n] 't (acc : *acc ([n]t)) (i: i64) (v: t) : *acc ([n]t) = - intrinsics.acc_write (acc, i, v) \ No newline at end of file + intrinsics.acc_write acc i v diff --git a/tests/slice-lmads/intrinsics.fut b/tests/slice-lmads/intrinsics.fut index ac32880267..6ac2b46804 100644 --- a/tests/slice-lmads/intrinsics.fut +++ b/tests/slice-lmads/intrinsics.fut @@ -1,17 +1,17 @@ def flat_index_2d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) : [n1][n2]a = - intrinsics.flat_index_2d(as, offset, n1, s1, n2, s2) :> [n1][n2]a + intrinsics.flat_index_2d as offset n1 s1 n2 s2 :> [n1][n2]a def flat_update_2d [n][k][l] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (asss: [k][l]a) : *[n]a = - intrinsics.flat_update_2d(as, offset, s1, s2, asss) + intrinsics.flat_update_2d as offset s1 s2 asss def flat_index_3d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) : [n1][n2][n3]a = - intrinsics.flat_index_3d(as, offset, n1, s1, n2, s2, n3, s3) :> [n1][n2][n3]a + intrinsics.flat_index_3d as offset n1 s1 n2 s2 n3 s3 :> [n1][n2][n3]a def flat_update_3d [n][k][l][p] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (asss: [k][l][p]a) : *[n]a = - intrinsics.flat_update_3d(as, offset, s1, s2, s3, asss) + intrinsics.flat_update_3d as offset s1 s2 s3 asss def flat_index_4d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) (n4: i64) (s4: i64) : [n1][n2][n3][n4]a = - intrinsics.flat_index_4d(as, offset, n1, s1, n2, s2, n3, s3, n4, s4) :> [n1][n2][n3][n4]a + intrinsics.flat_index_4d as offset n1 s1 n2 s2 n3 s3 n4 s4 :> [n1][n2][n3][n4]a def flat_update_4d [n][k][l][p][q] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (s4: i64) (asss: [k][l][p][q]a) : *[n]a = - intrinsics.flat_update_4d(as, offset, s1, s2, s3, s4, asss) + intrinsics.flat_update_4d as offset s1 s2 s3 s4 asss diff --git a/tests/uniqueness/uniqueness-error23.fut b/tests/uniqueness/uniqueness-error23.fut index a951e993c4..8cd37999e9 100644 --- a/tests/uniqueness/uniqueness-error23.fut +++ b/tests/uniqueness/uniqueness-error23.fut @@ -1,13 +1,14 @@ -- == --- error: .*consumed.* +-- error: self-aliased def g(ar: *[]i64, a: *[][]i64): i64 = ar[0] def f(ar: *[]i64, a: *[][]i64): i64 = - g(a[0], a) -- Should be a type error, as both are supposed to be unique + g(a[0], a) -- Should be a type error, as both are supposed to be + -- unique yet they alias each other. def main(n: i64): i64 = - let a = copy(replicate n (iota n)) + let a = replicate n (iota n) let ar = copy(a[0]) in f(ar, a) diff --git a/unittests/Language/Futhark/SyntaxTests.hs b/unittests/Language/Futhark/SyntaxTests.hs index 3cb8b9a242..6440bdb61c 100644 --- a/unittests/Language/Futhark/SyntaxTests.hs +++ b/unittests/Language/Futhark/SyntaxTests.hs @@ -153,11 +153,18 @@ pScalarType :: Parser (ScalarTypeBase Size ()) pScalarType = choice [try pFun, pScalarNonFun] where pFun = - uncurry (Arrow ()) <$> pParam <* lexeme "->" <*> pStructRetType + pParam <* lexeme "->" <*> pStructRetType pParam = - choice [try pNamedParam, (Unnamed,) <$> pNonFunType] - pNamedParam = - parens $ (,) <$> (Named <$> pVName) <* lexeme ":" <*> pStructType + choice + [ try pNamedParam, + do + t <- pNonFunType + pure $ Arrow () Unnamed (diet t) t + ] + pNamedParam = parens $ do + v <- pVName <* lexeme ":" + t <- pStructType + pure $ Arrow () (Named v) (diet t) t pStructRetType :: Parser StructRetType pStructRetType =