From c846c5fe7f4ca15b8a51af347cf86fc2bedcf68a Mon Sep 17 00:00:00 2001 From: catvayor <92800608+catvayor@users.noreply.github.com> Date: Mon, 24 Apr 2023 13:48:57 +0200 Subject: [PATCH] Support arbitrary size expressions Futhark is basically dependently typed now. Some work remains to be done, but this is surprisingly useful already. Closes #1659. --- docs/error-index.rst | 91 +- docs/glossary.rst | 20 +- futhark-benchmarks | 2 +- prelude/array.fut | 24 +- src/Futhark/CLI/Dev.hs | 2 +- src/Futhark/Doc/Generator.hs | 7 +- src/Futhark/IR/Parse.hs | 10 +- src/Futhark/Internalise/Defunctionalise.hs | 50 +- src/Futhark/Internalise/Exps.hs | 2 +- src/Futhark/Internalise/LiftLambdas.hs | 16 +- src/Futhark/Internalise/Monomorphise.hs | 725 ++++++++++--- src/Futhark/Internalise/TypesValues.hs | 7 +- src/Language/Futhark.hs | 80 -- src/Language/Futhark/FreeVars.hs | 39 +- src/Language/Futhark/Interpreter.hs | 293 ++++-- src/Language/Futhark/Interpreter/Values.hs | 16 +- src/Language/Futhark/Pretty.hs | 20 +- src/Language/Futhark/Prop.hs | 973 +++++++++++------- src/Language/Futhark/Syntax.hs | 46 +- src/Language/Futhark/Traversals.hs | 32 +- src/Language/Futhark/TypeChecker.hs | 22 +- src/Language/Futhark/TypeChecker/Modules.hs | 41 +- src/Language/Futhark/TypeChecker/Monad.hs | 34 +- src/Language/Futhark/TypeChecker/Terms.hs | 536 +++++++--- .../Futhark/TypeChecker/Terms/DoLoop.hs | 85 +- .../Futhark/TypeChecker/Terms/Monad.hs | 78 +- src/Language/Futhark/TypeChecker/Terms/Pat.hs | 12 +- src/Language/Futhark/TypeChecker/Types.hs | 122 +-- src/Language/Futhark/TypeChecker/Unify.hs | 203 ++-- tests/badentry10.fut | 7 + tests/badentry9.fut | 5 + tests/binding-warn0.fut | 6 + tests/binding-warn1.fut | 7 + tests/entryexpr.fut | 5 + tests/flattening/LoopInvReshape.fut | 2 +- tests/flattening/map-nested-free.fut | 9 + tests/fusion/Vers2.0/hindrReshape0.fut | 2 +- .../fusion/fuse-across-reshape-transpose.fut | 3 +- tests/fusion/fuse-across-reshape1.fut | 3 +- tests/fusion/fuse-across-reshape2.fut | 5 +- tests/fusion/fuse-across-reshape3.fut | 2 +- tests/higher-order-functions/alias0.fut | 2 +- .../higher-order-functions/shape-params6.fut | 9 + .../higher-order-functions/shape-params7.fut | 11 + tests/issue1112.fut | 2 +- tests/issue1237.fut | 38 +- tests/issue1239.fut | 2 +- tests/issue1424.fut | 2 +- tests/issue1435.fut | 1 - tests/issue245.fut | 2 +- tests/issue246.fut | 2 +- tests/issue869.fut | 2 +- tests/loops/pow2reduce.fut | 17 + .../coalescing/concat/neg0.fut | 2 +- .../coalescing/concat/pos1.fut | 2 +- .../coalescing/lud/lud.fut | 2 +- tests/modules/ascription-error7.fut | 8 + tests/modules/ascription15.fut | 7 + tests/modules/sizeparams8.fut | 5 + tests/modules/sizes2.fut | 2 +- tests/operator/size-section0.fut | 4 + tests/operator/size-section1.fut | 4 + tests/operator/size-section2.fut | 5 + tests/operator/size-section3.fut | 4 + tests/operator/size-section4.fut | 4 + tests/replicate3.fut | 2 +- tests/reshape1.fut | 4 +- tests/rosettacode/md5.fut | 6 +- tests/scatter/nw.fut | 41 +- tests/shapes/error4.fut | 2 +- tests/shapes/error5.fut | 2 +- tests/shapes/existential-argument.fut | 9 + tests/shapes/field-in-size.fut | 4 + tests/shapes/funshape3.fut | 2 +- tests/shapes/funshape5.fut | 2 +- tests/shapes/funshape6.fut | 2 +- tests/shapes/implicit-shape-use.fut | 15 +- tests/shapes/modules1.fut | 2 +- tests/shapes/opaque0.fut | 21 + tests/shapes/range0.fut | 4 +- tests/shapes/range1.fut | 4 +- tests/shapes/range2.fut | 2 +- tests/shapes/range3.fut | 8 + tests/shapes/size-inference2.fut | 2 +- tests/shapes/size-inference7.fut | 6 +- tests/size-expr-for-in.fut | 5 + tests/slice-lmads/flat.fut | 2 +- tests/slice-lmads/lud.fut | 2 +- tests/slice-lmads/small.fut | 2 +- tests/slice-lmads/small_4d.fut | 2 +- tests/sumtypes/existential-match.fut | 8 + tests/types/badsquare-lam.fut | 3 +- tests/types/badsquare.fut | 2 +- tests/types/ext2.fut | 2 +- tests/types/metasizes.fut | 52 + tests/types/sizeparams10.fut | 11 + tests/types/sizeparams11.fut | 10 + tests/uniqueness/uniqueness-error62.fut | 5 + tests/uniqueness/uniqueness-warn0.fut | 7 + tests/uniqueness/uniqueness-warn1.fut | 9 + tests/uniqueness/uniqueness59.fut | 4 + tests/uniqueness/uniqueness60.fut | 5 + tests_repl/issue1347.fut | 2 +- tests_repl/issue1347.in | 2 +- unittests/Language/Futhark/SyntaxTests.hs | 8 +- .../Futhark/TypeChecker/TypesTests.hs | 3 +- 106 files changed, 2752 insertions(+), 1317 deletions(-) create mode 100644 tests/badentry10.fut create mode 100644 tests/badentry9.fut create mode 100644 tests/binding-warn0.fut create mode 100644 tests/binding-warn1.fut create mode 100644 tests/entryexpr.fut create mode 100644 tests/flattening/map-nested-free.fut create mode 100644 tests/higher-order-functions/shape-params6.fut create mode 100644 tests/higher-order-functions/shape-params7.fut create mode 100644 tests/loops/pow2reduce.fut create mode 100644 tests/modules/ascription-error7.fut create mode 100644 tests/modules/ascription15.fut create mode 100644 tests/modules/sizeparams8.fut create mode 100644 tests/operator/size-section0.fut create mode 100644 tests/operator/size-section1.fut create mode 100644 tests/operator/size-section2.fut create mode 100644 tests/operator/size-section3.fut create mode 100644 tests/operator/size-section4.fut create mode 100644 tests/shapes/existential-argument.fut create mode 100644 tests/shapes/field-in-size.fut create mode 100644 tests/shapes/opaque0.fut create mode 100644 tests/shapes/range3.fut create mode 100644 tests/size-expr-for-in.fut create mode 100644 tests/sumtypes/existential-match.fut create mode 100644 tests/types/metasizes.fut create mode 100644 tests/types/sizeparams10.fut create mode 100644 tests/types/sizeparams11.fut create mode 100644 tests/uniqueness/uniqueness-error62.fut create mode 100644 tests/uniqueness/uniqueness-warn0.fut create mode 100644 tests/uniqueness/uniqueness-warn1.fut create mode 100644 tests/uniqueness/uniqueness59.fut create mode 100644 tests/uniqueness/uniqueness60.fut diff --git a/docs/error-index.rst b/docs/error-index.rst index a68f24fe49..e144f82ed3 100644 --- a/docs/error-index.rst +++ b/docs/error-index.rst @@ -267,6 +267,44 @@ not any free variables. Use ``copy`` to fix this: def f () = copy x +.. _size-expression-bind: + +"Size expression with binding is replaced by unknown size." +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To illustrate this error, consider the following program + +.. code-block:: futhark + + def main (xs: *[]i64) = + let a = iota (let n = 10 in n+n) + in ... + +Intuitively, the type of ``a`` should be ``[let n = 10 in n+n]i32``, +but this puts a binding into a size expression, which is invalid. +Therefore, the type checker invents an :term:`unknown size` +variable, say ``l``, and assigns ``a`` the type ``[l]i32``. + +.. _size-expression-consume: + +"Size expression with consumption is replaced by unknown size." +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To illustrate this error, consider the following program + +.. code-block:: futhark + + def consume (xs: *[]i64): i64 = xs[0] + + def main (xs: *[]i64) = + let a = iota (consume xs) + in ... + +Intuitively, the type of ``a`` should be ``[consume ys]i32``, but this +puts a consumption of the array ``ys`` into a size expression, which +is invalid. Therefore, the type checker invents an :term:`unknown +size` variable, say ``l``, and assigns ``a`` the type ``[l]i32``. + .. _inaccessible-size: "Parameter *x* refers to size *y* which will not be accessible to the caller @@ -385,7 +423,7 @@ into a separate ``let``-binding preceding the problematic expressions. This error occurs when you define a function that can never be applied, as it requires an input of a specific size, and that size is -not known. Somewhat contrived example: +an :term:`unknown size`. Somewhat contrived example: .. code-block:: futhark @@ -471,8 +509,9 @@ use of either pipelining or composition. "Existential size *n* not used as array size" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This error occurs for type expressions that use explicit existential -quantification in an incorrect way, such as the following examples: +This error occurs for type expressions that bind an existential size +for which there is no :term:`constructive use`, such as in the +following examples: .. code-block:: futhark @@ -481,12 +520,12 @@ quantification in an incorrect way, such as the following examples: ?[n].bool -> [n]bool When we use existential quantification, we are required to use the -size within its scope, *and* it must not exclusively be used to the -right of function arrow. +size constructively within its scope, *in particular* it must not be +exclusively as the parameter or return type of a function. To understand the motivation behind this rule, consider that when we -use an existential quantifier we are saying that there is *some size*, -it just cannot be known statically, but must be read from some value +use an existential quantifier we are saying that there is *some size*. +The size is not known statically, but must be read from some value (i.e. array) at runtime. In the first example above, the existential size ``n`` is not used at all, so the actual value cannot be determined at runtime. In the second example, while an array @@ -943,3 +982,41 @@ ordinary parameter: .. code-block:: futhark entry f (n: i64) : [0][n]i32 = [] + +.. _nonconstructive-entry: + +"Entry point size parameter [n] only used non-constructively." +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This error occurs for programs such as the following:: + +.. code-block:: futhark + + entry main [x] (A: [x+1]i32) = ... + +The size parameter ``[x]`` is only used in an size expression ``x+1``, +rather than directly as an array size. This is allowed for ordinary +functions, but not for entry points. The reason is that entry points +are not subject to ordinary type inference, as they are called from +the external world, meaning that the value of the size parameter +``[x]`` will have to be determined from the size of the array ``A``. +This is in principle not a problem for simple sizes like ``x+1``, as +it is obvious that ``x == length A - 1``, but in the general case it +would require computing function inverses that might not exist. For +this reason, entry points require that all size parameters are used +:term:`constructively`. + +As a workaround, you can rewrite the entry point as follows: + +.. code-block:: futhark + + entry main [n] (A: [n]i32) = + let x = n-1 + let A = A :> [x+1]i32 + ... + +Or by passing the ``x`` explicitly: + +.. code-block:: futhark + + entry main (x: i64) (A: [x+1]i32) = ... diff --git a/docs/glossary.rst b/docs/glossary.rst index f40e4d49db..b7b6a25686 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -73,11 +73,13 @@ documentation and in compiler output. * ``[n]bool`` * ``([n]bool, bool -> [n]bool)`` + * ``([n]bool, [n+1]bool)`` The following do not: * ``[n+1]bool`` * ``bool -> [n]bool`` + * ``[n]bool -> bool`` Consumption @@ -296,6 +298,20 @@ documentation and in compiler output. The symbolic size of an array dimension or :term:`abstract type`. + Size expression + + An expression that occurs as the size of an array or size + argument. For example, in the type ``[x+2]i32``, ``x+2`` is a + size expression. Size expressions can occur syntactically in + source code, or due to parameter substitution when applying a + :term:`size-dependent function`. + + Size-dependent function + + A function where the size of the result depends on the values of + the parameters. The function ``iota`` is perhaps the simplest + example. + Size types Size-dependent types @@ -385,7 +401,9 @@ documentation and in compiler output. Unknown size A size produced by invoking a function whose result type contains - an existentially quantified size, such as ``filter``. + an existentially quantified size, such as ``filter``, or because + the original :term:`size expression` involves variables that have + gone out of scope. Value diff --git a/futhark-benchmarks b/futhark-benchmarks index 50ca9af348..5e63f33d25 160000 --- a/futhark-benchmarks +++ b/futhark-benchmarks @@ -1 +1 @@ -Subproject commit 50ca9af3485e6debfd3f41c40f3dcd3b4d38f7f2 +Subproject commit 5e63f33d25f15071c4141710b9acff94f5232734 diff --git a/prelude/array.fut b/prelude/array.fut index b58046a4d7..14ff6efcc6 100644 --- a/prelude/array.fut +++ b/prelude/array.fut @@ -28,12 +28,12 @@ def last [n] 't (x: [n]t) = x[n-1] -- | Everything but the first element of the array. -- -- **Complexity:** O(1). -def tail [n] 't (x: [n]t) = x[1:] +def tail [n] 't (x: [n]t): [n-1]t = x[1:] -- | Everything but the last element of the array. -- -- **Complexity:** O(1). -def init [n] 't (x: [n]t) = x[0:n-1] +def init [n] 't (x: [n]t): [n-1]t = x[0:n-1] -- | Take some number of elements from the head of the array. -- @@ -43,12 +43,12 @@ def take [n] 't (i: i64) (x: [n]t): [i]t = x[0:i] -- | Remove some number of elements from the head of the array. -- -- **Complexity:** O(1). -def drop [n] 't (i: i64) (x: [n]t) = x[i:] +def drop [n] 't (i: i64) (x: [n]t): [n-i]t = x[i:] -- | Split an array at a given position. -- -- **Complexity:** O(1). -def split [n] 't (i: i64) (xs: [n]t): ([i]t, []t) = +def split [n] 't (i: i64) (xs: [n]t): ([i]t, [n-i]t) = (xs[0:i], xs[i:]) -- | Return the elements of the array in reverse order. @@ -62,10 +62,10 @@ def reverse [n] 't (x: [n]t): [n]t = x[::-1] -- **Work:** O(n). -- -- **Span:** O(1). -def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = intrinsics.concat xs ys +def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = intrinsics.concat xs ys -- | An old-fashioned way of saying `++`. -def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = xs ++ ys +def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys -- | Concatenation where the result has a predetermined size. If the -- provided size is wrong, the function will fail with a run-time @@ -123,7 +123,7 @@ def copy 't (a: t): *t = -- | Combines the outer two dimensions of an array. -- -- **Complexity:** O(1). -def flatten [n][m] 't (xs: [n][m]t): []t = +def flatten [n][m] 't (xs: [n][m]t): [n*m]t = intrinsics.flatten xs -- | Like `flatten`@term, but where the final size is known. Fails at @@ -132,25 +132,25 @@ def flatten_to [n][m] 't (l: i64) (xs: [n][m]t): [l]t = flatten xs :> [l]t -- | Like `flatten`, but on the outer three dimensions of an array. -def flatten_3d [n][m][l] 't (xs: [n][m][l]t): []t = +def flatten_3d [n][m][l] 't (xs: [n][m][l]t): [n*m*l]t = flatten (flatten xs) -- | Like `flatten`, but on the outer four dimensions of an array. -def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): []t = +def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): [n*m*l*k]t = flatten (flatten_3d xs) -- | Splits the outer dimension of an array in two. -- -- **Complexity:** O(1). -def unflatten [p] 't (n: i64) (m: i64) (xs: [p]t): [n][m]t = +def unflatten 't (n: i64) (m: i64) (xs: [n*m]t): [n][m]t = intrinsics.unflatten n m xs :> [n][m]t -- | Like `unflatten`, but produces three dimensions. -def unflatten_3d [p] 't (n: i64) (m: i64) (l: i64) (xs: [p]t): [n][m][l]t = +def unflatten_3d 't (n: i64) (m: i64) (l: i64) (xs: [n*m*l]t): [n][m][l]t = unflatten n m (unflatten (n*m) l xs) -- | Like `unflatten`, but produces four dimensions. -def unflatten_4d [p] 't (n: i64) (m: i64) (l: i64) (k: i64) (xs: [p]t): [n][m][l][k]t = +def unflatten_4d 't (n: i64) (m: i64) (l: i64) (k: i64) (xs: [n*m*l*k]t): [n][m][l][k]t = unflatten n m (unflatten_3d (n*m) l k xs) -- | Transpose an array. diff --git a/src/Futhark/CLI/Dev.hs b/src/Futhark/CLI/Dev.hs index e02d9fabb5..77a106fff3 100644 --- a/src/Futhark/CLI/Dev.hs +++ b/src/Futhark/CLI/Dev.hs @@ -76,7 +76,7 @@ data FutharkPipeline Defunctorise | -- | Defunctorise and monomorphise. Monomorphise - | -- | Defunctorise, monomorphise, and lambda-lift. + | -- | Defunctorise, monomorphise and lambda-lift. LiftLambdas | -- | Defunctorise, monomorphise, lambda-lift, and defunctionalise. Defunctionalise diff --git a/src/Futhark/Doc/Generator.hs b/src/Futhark/Doc/Generator.hs index 7f307e0eb2..936d91d6e2 100644 --- a/src/Futhark/Doc/Generator.hs +++ b/src/Futhark/Doc/Generator.hs @@ -541,8 +541,8 @@ prettyShape (Shape ds) = mconcat <$> mapM dimDeclHtml ds typeArgHtml :: TypeArg Size -> DocM Html -typeArgHtml (TypeArgDim d _) = dimDeclHtml d -typeArgHtml (TypeArgType t _) = typeHtml t +typeArgHtml (TypeArgDim d) = dimDeclHtml d +typeArgHtml (TypeArgType t) = typeHtml t modParamHtml :: [ModParamBase Info VName] -> DocM Html modParamHtml [] = pure mempty @@ -706,8 +706,7 @@ relativise dest src = concat (replicate (length (splitPath src) - 1) "../") ++ dest dimDeclHtml :: Size -> DocM Html -dimDeclHtml (NamedSize v) = brackets <$> qualNameHtml v -dimDeclHtml (ConstSize n) = pure $ brackets $ toHtml (show n) +dimDeclHtml (SizeExpr e) = pure $ brackets $ toHtml $ prettyString e dimDeclHtml AnySize {} = pure $ brackets mempty dimExpHtml :: SizeExp Info VName -> DocM Html diff --git a/src/Futhark/IR/Parse.hs b/src/Futhark/IR/Parse.hs index aa3328d897..bf3e14161d 100644 --- a/src/Futhark/IR/Parse.hs +++ b/src/Futhark/IR/Parse.hs @@ -12,7 +12,7 @@ where import Data.Char (isAlpha) import Data.Functor -import Data.List (zipWith4) +import Data.List (singleton, zipWith4) import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Set qualified as S @@ -54,13 +54,13 @@ pName = pVName :: Parser VName pVName = lexeme $ do (s, tag) <- - satisfy constituent + choice [exprBox, singleton <$> satisfy constituent] `manyTill_` try pTag "variable name" - pure $ VName (nameFromString s) tag + pure $ VName (nameFromString $ concat s) tag where - pTag = - "_" *> L.decimal <* notFollowedBy (satisfy constituent) + pTag = "_" *> L.decimal <* notFollowedBy (satisfy constituent) + exprBox = ("<{" <>) . (<> "}>") <$> (chunk "<{" *> manyTill anySingle (chunk "}>")) pBool :: Parser Bool pBool = choice [keyword "true" $> True, keyword "false" $> False] diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs index 730e97bae0..1678b22ff2 100644 --- a/src/Futhark/Internalise/Defunctionalise.hs +++ b/src/Futhark/Internalise/Defunctionalise.hs @@ -72,11 +72,11 @@ replaceTypeSizes :: TypeBase Size als replaceTypeSizes substs = first onDim where - onDim (NamedSize v) = + onDim (SizeExpr (Var v typ loc)) = case M.lookup (qualLeaf v) substs of - Just (SubstNamed v') -> NamedSize v' - Just (SubstConst d) -> ConstSize d - Nothing -> NamedSize v + Just (SubstNamed v') -> SizeExpr (Var v' typ loc) + Just (SubstConst d) -> sizeFromInteger (toInteger d) loc + Nothing -> SizeExpr (Var v typ loc) onDim d = d replaceStaticValSizes :: @@ -268,15 +268,15 @@ arraySizes (Scalar (Sum cs)) = foldMap (foldMap arraySizes) cs arraySizes (Scalar (TypeVar _ _ _ targs)) = mconcat $ map f targs where - f (TypeArgDim (NamedSize d) _) = S.singleton $ qualLeaf d + f (TypeArgDim (SizeExpr (Var d _ _))) = S.singleton $ qualLeaf d f TypeArgDim {} = mempty - f (TypeArgType t _) = arraySizes t + f (TypeArgType t) = arraySizes t arraySizes (Scalar Prim {}) = mempty arraySizes (Array _ _ shape t) = arraySizes (Scalar t) <> foldMap dimName (shapeDims shape) where dimName :: Size -> S.Set VName - dimName (NamedSize qn) = S.singleton $ qualLeaf qn + dimName (SizeExpr (Var qn _ _)) = S.singleton $ qualLeaf qn dimName _ = mempty patternArraySizes :: Pat -> S.Set VName @@ -294,14 +294,14 @@ dimMapping :: M.Map VName SizeSubst dimMapping t1 t2 = execState (matchDims f t1 t2) mempty where - f bound d1 (NamedSize d2) + f bound d1 (SizeExpr (Var d2 _ _)) | qualLeaf d2 `elem` bound = pure d1 - f _ (NamedSize d1) (NamedSize d2) = do + f _ (SizeExpr (Var d1 typ loc)) (SizeExpr (Var d2 _ _)) = do modify $ M.insert (qualLeaf d1) $ SubstNamed d2 - pure $ NamedSize d1 - f _ (NamedSize d1) (ConstSize d2) = do - modify $ M.insert (qualLeaf d1) $ SubstConst d2 - pure $ NamedSize d1 + pure $ SizeExpr $ Var d1 typ loc + f _ (SizeExpr (Var d1 typ loc)) (SizeExpr (IntLit d2 _ _)) = do + modify $ M.insert (qualLeaf d1) $ SubstConst $ fromInteger d2 + pure $ SizeExpr $ Var d1 typ loc f _ d _ = pure d dimMapping' :: @@ -328,7 +328,7 @@ sizesToRename (RecordSV fs) = sizesToRename (SumSV _ svs _) = foldMap sizesToRename svs sizesToRename (LambdaSV param _ _ _) = - freeInPat param + fvVars (freeInPat param) <> S.map identName (S.filter couldBeSize $ patIdents param) where couldBeSize ident = @@ -714,7 +714,7 @@ etaExpand e_t e = do ext' <- mapM newName $ retDims ret let extsubst = M.fromList . zip (retDims ret) $ - map (SizeSubst . NamedSize . qualName) ext' + map (ExpSubst . flip sizeVar mempty . qualName) ext' ret' = applySubst (`M.lookup` extsubst) ret e' = mkApply @@ -762,7 +762,7 @@ defuncLet :: DefM ([VName], [Pat], Exp, StaticVal) defuncLet dims ps@(pat : pats) body (RetType ret_dims rettype) | patternOrderZero pat = do - let bound_by_pat = (`S.member` freeInPat pat) + let bound_by_pat = (`S.member` fvVars (freeInPat pat)) -- Take care to not include more size parameters than necessary. (pat_dims, rest_dims) = partition bound_by_pat dims env = envFromPat pat <> envFromDimNames pat_dims @@ -797,24 +797,24 @@ sizesForAll bound_sizes params = do tv = identityMapper {mapOnPatType = bitraverse onDim pure} onDim (AnySize (Just v)) = do modify $ S.insert v - pure $ NamedSize $ qualName v + pure $ sizeFromName (qualName v) mempty onDim (AnySize Nothing) = do v <- lift $ newVName "size" modify $ S.insert v - pure $ NamedSize $ qualName v - onDim (NamedSize d) = do + pure $ sizeFromName (qualName v) mempty + onDim (SizeExpr (Var d typ loc)) = do unless (qualLeaf d `S.member` bound) $ modify $ S.insert $ qualLeaf d - pure $ NamedSize d + pure $ SizeExpr $ Var d typ loc onDim d = pure d unRetType :: StructRetType -> StructType unRetType (RetType [] t) = t unRetType (RetType ext t) = first onDim t where - onDim (NamedSize d) | qualLeaf d `elem` ext = AnySize Nothing + onDim (SizeExpr (Var d _ _)) | qualLeaf d `elem` ext = AnySize Nothing onDim d = d defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal) @@ -1041,7 +1041,7 @@ liftValDec fname (RetType ret_dims ret) dims pats body = addValBind dec mkExt v | not $ v `S.member` bound_here = Just v mkExt _ = Nothing - rettype_st = RetType (mapMaybe mkExt (S.toList (freeInType ret)) ++ ret_dims) ret + rettype_st = RetType (mapMaybe mkExt (M.keys $ unFV $ freeInType ret) ++ ret_dims) ret dec = ValBind @@ -1157,7 +1157,7 @@ matchPatSV (Id vn (Info t) _) sv = else dim_env <> M.singleton vn (Binding Nothing sv) where dim_env = - M.fromList $ map (,i64) $ S.toList $ freeInType t + M.map (const i64) $ unFV $ freeInType t i64 = Binding Nothing $ Dynamic $ Scalar $ Prim $ Signed Int64 matchPatSV (Wildcard _ _) _ = pure mempty matchPatSV (PatAscription pat _ _) sv = matchPatSV pat sv @@ -1283,7 +1283,7 @@ defuncValBind valbind@(ValBind _ name retdecl (Info (RetType ret_dims rettype)) -- applications of lifted functions, we don't properly update -- the types in the return type annotation. combineTypeShapes rettype $ first (anyDimIfNotBound bound_sizes) $ toStruct $ typeOf body' - ret_dims' = filter (`S.member` freeInType rettype') ret_dims + ret_dims' = filter (`M.member` (unFV $ freeInType rettype')) ret_dims (missing_dims, params'') <- sizesForAll bound_sizes params' pure @@ -1305,7 +1305,7 @@ defuncValBind valbind@(ValBind _ name retdecl (Info (RetType ret_dims rettype)) sv ) where - anyDimIfNotBound bound_sizes (NamedSize v) + anyDimIfNotBound bound_sizes (SizeExpr (Var v _ _)) | qualLeaf v `S.notMember` bound_sizes = AnySize $ Just $ qualLeaf v anyDimIfNotBound _ d = d diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 1c349dc150..77501e50bb 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -343,7 +343,7 @@ internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = -- application. One caveat is that we need to replace any -- existential sizes, too (with zeroes, because they don't -- matter). - let subst = zip ext $ repeat $ E.SizeSubst $ E.ConstSize 0 + let subst = zip ext $ repeat $ E.ExpSubst $ E.sizeInteger 0 mempty et' = E.applySubst (`lookup` subst) et internaliseExp desc (E.Hole (Info et') loc) (FunctionName qfname, args) -> do diff --git a/src/Futhark/Internalise/LiftLambdas.hs b/src/Futhark/Internalise/LiftLambdas.hs index 44f3282709..f68088527b 100644 --- a/src/Futhark/Internalise/LiftLambdas.hs +++ b/src/Futhark/Internalise/LiftLambdas.hs @@ -68,10 +68,6 @@ existentials e = m = identityMapper {mapOnExp = \e' -> modify (<> existentials e') >> pure e'} in execState (astMap m e) here -freeSizes :: S.Set VName -> FV -freeSizes vs = - FV $ M.fromList $ zip (S.toList vs) $ repeat $ Scalar $ Prim $ Signed Int64 - liftFunction :: VName -> [TypeParam] -> [Pat] -> StructRetType -> Exp -> LiftM Exp liftFunction fname tparams params (RetType dims ret) funbody = do -- Find free variables @@ -87,16 +83,18 @@ liftFunction fname tparams params (RetType dims ret) funbody = do sizes_in_free = foldMap freeInType $ M.elems $ unFV immediate_free sizes = - freeSizes $ - sizes_in_free - <> foldMap freeInPat params - <> freeInType ret + FV $ + M.map (const (Scalar $ Prim $ Signed Int64)) $ + unFV $ + sizes_in_free + <> foldMap freeInPat params + <> freeInType ret in M.toList $ unFV $ immediate_free <> (sizes `freeWithout` bound) -- Those parameters that correspond to sizes must come first. sizes_in_types = foldMap freeInType (ret : map snd free ++ map patternStructType params) - isSize (v, _) = v `S.member` sizes_in_types + isSize (v, _) = v `M.member` unFV sizes_in_types (free_dims, free_nondims) = partition isSize free free_params = diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs index 08ad1e325f..e542549ea1 100644 --- a/src/Futhark/Internalise/Monomorphise.hs +++ b/src/Futhark/Internalise/Monomorphise.hs @@ -20,14 +20,19 @@ -- -- * Rewrite BinOp nodes to Apply nodes. -- +-- * Replace all size expressions by constants or variables, +-- complex expressions replaced by variables are calculated in +-- let binding or replaced by size parameters if in argument. +-- -- Note that these changes are unfortunately not visible in the AST -- representation. module Futhark.Internalise.Monomorphise (transformProg) where import Control.Monad +import Control.Monad.Identity import Control.Monad.RWS (MonadReader (..), MonadWriter (..), RWST, asks, runRWST) import Control.Monad.State -import Control.Monad.Writer (runWriterT) +import Control.Monad.Writer (Writer, runWriter, runWriterT) import Data.Bifunctor import Data.Bitraversable import Data.Foldable @@ -38,6 +43,7 @@ import Data.Maybe import Data.Sequence qualified as Seq import Data.Set qualified as S import Futhark.MonadFreshNames +import Futhark.Util (nubOrd) import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.Semantic (TypeBinding (..)) @@ -73,19 +79,86 @@ type RecordReplacements = M.Map VName RecordReplacement type RecordReplacement = M.Map Name (VName, PatType) +-- | To deduplicate size expressions, we want a looser notation of +-- equality than the strict syntactical equality provided by the Eq +-- instance on Exp. This newtype wrapper provides such a looser +-- notion of equality. +newtype ReplacedExp = ReplacedExp {unReplaced :: Exp} + deriving (Show) + +instance Pretty ReplacedExp where + pretty (ReplacedExp e) = pretty e + +instance Eq ReplacedExp where + ReplacedExp e1 == ReplacedExp e2 + | Just es <- similarExps e1 e2 = + all (uncurry (==) . bimap ReplacedExp ReplacedExp) es + _ == _ = False + +type ExpReplacements = [(ReplacedExp, VName)] + +canCalculate :: S.Set VName -> ExpReplacements -> ExpReplacements +canCalculate scope mapping = do + filter + ( (`S.isSubsetOf` scope) + . S.filter notIntrisic + . fvVars + . freeInExp + . unReplaced + . fst + ) + mapping + where + notIntrisic vn = baseTag vn > maxIntrinsicTag + +-- Replace some expressions by a parameter. +expReplace :: ExpReplacements -> Exp -> Exp +expReplace mapping e + | Just vn <- lookup (ReplacedExp e) mapping = + Var (qualName vn) (Info $ typeOf e) (srclocOf e) +expReplace mapping e = runIdentity $ astMap mapper e + where + mapper = identityMapper {mapOnExp = pure . expReplace mapping} + +-- Construct an Assert expression that checks that the names (values) +-- in the mapping have the same value as the expression they +-- represent. This is injected into entry points, where we cannot +-- otherwise trust the input. XXX: the error message generated from +-- this is not great; we should rework it eventually. +entryAssert :: ExpReplacements -> Exp -> Exp +entryAssert [] body = body +entryAssert (x : xs) body = + Assert (foldl logAnd (cmpExp x) $ map cmpExp xs) body errmsg (srclocOf body) + where + errmsg = Info "entry point arguments have invalid sizes." + bool = Scalar $ Prim Bool + opt = foldFunType [(Observe, bool), (Observe, bool)] $ RetType [] bool + andop = Var (qualName (intrinsicVar "&&")) (Info opt) mempty + eqop = Var (qualName (intrinsicVar "==")) (Info opt) mempty + logAnd x' y = + mkApply andop [(Observe, Nothing, x'), (Observe, Nothing, y)] $ + AppRes bool [] + cmpExp (ReplacedExp x', y) = + mkApply eqop [(Observe, Nothing, x'), (Observe, Nothing, y')] $ + AppRes bool [] + where + y' = Var (qualName y) (Info i64) mempty + -- Monomorphization environment mapping names of polymorphic functions -- to a representation of their corresponding function bindings. data Env = Env { envPolyBindings :: M.Map VName PolyBinding, envTypeBindings :: M.Map VName TypeBinding, - envRecordReplacements :: RecordReplacements + envRecordReplacements :: RecordReplacements, + envScope :: S.Set VName, + envParametrized :: ExpReplacements } instance Semigroup Env where - Env tb1 pb1 rr1 <> Env tb2 pb2 rr2 = Env (tb1 <> tb2) (pb1 <> pb2) (rr1 <> rr2) + Env tb1 pb1 rr1 sc1 pr1 <> Env tb2 pb2 rr2 sc2 pr2 = Env (tb1 <> tb2) (pb1 <> pb2) (rr1 <> rr2) (sc1 <> sc2) (pr1 <> pr2) instance Monoid Env where - mempty = Env mempty mempty mempty + mempty = Env mempty mempty mempty mempty mempty localEnv :: Env -> MonoM a -> MonoM a localEnv env = local (env <>) @@ -101,13 +174,27 @@ withRecordReplacements rr = localEnv mempty {envRecordReplacements = rr} replaceRecordReplacements :: RecordReplacements -> MonoM a -> MonoM a replaceRecordReplacements rr = local $ \env -> env {envRecordReplacements = rr} +isolateNormalisation :: MonoM a -> MonoM a +isolateNormalisation m = do + prevRepl <- get + put mempty + ret <- local (\env -> env {envScope = mempty, envParametrized = mempty}) m + put prevRepl + pure ret + +withArgs :: S.Set VName -> MonoM a -> MonoM a +withArgs args = localEnv $ mempty {envScope = args} + +withParams :: ExpReplacements -> MonoM a -> MonoM a +withParams params = localEnv $ mempty {envParametrized = params} + -- The monomorphization monad. newtype MonoM a = MonoM ( RWST Env (Seq.Seq (VName, ValBind)) - VNameSource + (ExpReplacements, VNameSource) (State Lifts) a ) @@ -116,14 +203,21 @@ newtype MonoM a Applicative, Monad, MonadReader Env, - MonadWriter (Seq.Seq (VName, ValBind)), - MonadFreshNames + MonadWriter (Seq.Seq (VName, ValBind)) ) +instance MonadFreshNames MonoM where + getNameSource = MonoM $ gets snd + putNameSource = MonoM . modify . second . const + +instance MonadState ExpReplacements MonoM where + get = MonoM $ gets fst + put = MonoM . modify . first . const + runMonoM :: VNameSource -> MonoM a -> ((a, Seq.Seq (VName, ValBind)), VNameSource) runMonoM src (MonoM m) = ((a, defs), src') where - (a, src', defs) = evalState (runRWST m mempty src) mempty + (a, (_, src'), defs) = evalState (runRWST m mempty (mempty, src)) mempty lookupFun :: VName -> MonoM (Maybe PolyBinding) lookupFun vn = do @@ -135,8 +229,100 @@ lookupFun vn = do lookupRecordReplacement :: VName -> MonoM (Maybe RecordReplacement) lookupRecordReplacement v = asks $ M.lookup v . envRecordReplacements +askScope :: MonoM (S.Set VName) +askScope = do + scope <- asks envScope + scope' <- asks $ S.union scope . M.keysSet . envPolyBindings + S.union scope' . S.fromList . map (fst . snd) <$> getLifts + +-- | Asks the introduced variables in a set of argument, +-- that is arguments not currently in scope. +askIntros :: S.Set VName -> MonoM (S.Set VName) +askIntros argset = + (S.filter notIntrisic argset `S.difference`) <$> askScope + where + notIntrisic vn = baseTag vn > maxIntrinsicTag + +-- | Gets and removes expressions that could not be calculated when +-- the arguments set will be unscoped. +-- This should be called without argset in scope, for good detection of intros. +parametrizing :: S.Set VName -> MonoM ExpReplacements +parametrizing argset = do + intros <- askIntros argset + (params, nxtBind) <- gets $ partition (not . S.disjoint intros . fvVars . freeInExp . unReplaced . fst) + put nxtBind + pure params + +calculateDims :: Exp -> ExpReplacements -> MonoM Exp +calculateDims body repl = + foldCalc top_repl $ expReplace top_repl body + where + ---- topological sorting + exp_idxs = zip (map fst repl) [0 ..] + -- list of strict sub-expressions of e + subExps e + | Just e' <- stripExp e = subExps e' + | otherwise = astMap mapper e `execState` mempty + where + mapOnExp e' + | Just e'' <- stripExp e' = mapOnExp e'' + | otherwise = do + modify (ReplacedExp e' :) + astMap mapper e' + mapper = identityMapper {mapOnExp} + -- @a `depends_of` (b,i)@ returns @Just i@ + -- iff b appear in a as an expression + depends_of a (b, i) = + if b `elem` subExps (unReplaced a) + then Just i + else Nothing + -- graph of dependencies, represented with adjacency list + depends_graph = + map (\(e, _) -> mapMaybe (depends_of e) exp_idxs) exp_idxs + + sorting i = do + done <- gets $ (!! i) . snd + unless done $ do + mapM_ sorting $ depends_graph !! i + modify $ bimap (repl !! i :) (\status -> map (\j -> i == j || status !! j) [0 .. length status]) + top_repl = + fst $ execState (mapM_ (sorting . snd) exp_idxs) (mempty, map (const False) exp_idxs) + + ---- Calculus insertion + foldCalc [] body' = pure body' + foldCalc ((dim, vn) : repls) body' = do + reName <- newName vn + let expr = expReplace repls $ unReplaced dim + subst vn' = + if vn' == vn + then Just $ ExpSubst $ sizeVar (qualName reName) mempty + else Nothing + appRes = case body' of + (AppExp _ (Info (AppRes ty ext))) -> Info $ AppRes (applySubst subst ty) (reName : ext) + e -> Info $ AppRes (applySubst subst $ typeOf e) [reName] + foldCalc repls $ + AppExp + ( LetPat + [] + (Id vn (Info i64) (srclocOf expr)) + expr + body' + mempty + ) + appRes + +unscoping :: S.Set VName -> Exp -> MonoM Exp +unscoping argset body = do + localDims <- parametrizing argset + scope <- S.union argset <$> askScope + calculateDims body $ canCalculate scope localDims + +scoping :: S.Set VName -> MonoM Exp -> MonoM Exp +scoping argset m = + withArgs argset m >>= unscoping argset + -- Given instantiated type of function, produce size arguments. -type InferSizeArgs = StructType -> [Exp] +type InferSizeArgs = StructType -> MonoM [Exp] data MonoSize = -- | The integer encodes an equivalence class, so we can keep @@ -166,7 +352,7 @@ type MonoType = TypeBase MonoSize () monoType :: TypeBase Size als -> MonoType monoType = (`evalState` (0, mempty)) . traverseDims onDim . toStruct where - onDim bound _ (NamedSize d) + onDim bound _ (SizeExpr (Var d _ _)) -- A locally bound size. | qualLeaf d `S.member` bound = pure $ MonoAnon $ qualLeaf d onDim _ _ d = do @@ -195,6 +381,31 @@ addLifted fname il liftf = lookupLifted :: VName -> MonoType -> MonoM (Maybe (VName, InferSizeArgs)) lookupLifted fname t = lookup (fname, t) <$> getLifts +-- | Creates a new expression replacement if needed, this always produces normalised sizes. +-- (e.g. single variable or constant) +replaceExp :: Exp -> MonoM Exp +replaceExp e = + case maybeNormalisedSize e of + Just e' -> pure e' + Nothing -> do + let e' = ReplacedExp e + prev <- gets $ lookup e' + prev_param <- asks $ lookup e' . envParametrized + case (prev_param, prev) of + (Just vn, _) -> pure $ sizeVar (qualName vn) (srclocOf e) + (Nothing, Just vn) -> pure $ sizeVar (qualName vn) (srclocOf e) + (Nothing, Nothing) -> do + vn <- newNameFromString $ "d<{" ++ prettyString e ++ "}>" + modify ((e', vn) :) + pure $ sizeVar (qualName vn) (srclocOf e) + where + -- Avoid replacing of some 'already normalised' sizes that are just surounded by some parentheses. + maybeNormalisedSize e' + | Just e'' <- stripExp e' = maybeNormalisedSize e'' + maybeNormalisedSize (Var qn _ loc) = Just $ sizeVar qn loc + maybeNormalisedSize (IntLit v _ loc) = Just $ IntLit v (Info i64) loc + maybeNormalisedSize _ = Nothing + transformFName :: SrcLoc -> QualName VName -> StructType -> MonoM Exp transformFName loc fname t | baseTag (qualLeaf fname) <= maxIntrinsicTag = pure $ var fname @@ -206,7 +417,7 @@ transformFName loc fname t case (maybe_fname, maybe_funbind) of -- The function has already been monomorphised. (Just (fname', infer), _) -> - pure $ applySizeArgs fname' t' $ infer t' + applySizeArgs fname' t' <$> infer t' -- An intrinsic function. (Nothing, Nothing) -> pure $ var fname -- A polymorphic function. @@ -214,7 +425,7 @@ transformFName loc fname t (fname', infer, funbind') <- monomorphiseBinding False funbind mono_t tell $ Seq.singleton (qualLeaf fname, funbind') addLifted (qualLeaf fname) mono_t (fname', infer) - pure $ applySizeArgs fname' t' $ infer t' + applySizeArgs fname' t' <$> infer t' where var fname' = Var fname' (Info (fromStruct t)) loc @@ -243,20 +454,102 @@ transformFName loc fname t ) size_args +transformTypeSizes :: TypeBase Size as -> MonoM (TypeBase Size as) +transformTypeSizes typ = + case typ of + Scalar scalar -> Scalar <$> transformScalarSizes scalar + Array as u shape scalar -> Array as u <$> mapM onDim shape <*> transformScalarSizes scalar + where + transformScalarSizes (Record fs) = + Record <$> traverse transformTypeSizes fs + transformScalarSizes (Sum cs) = + Sum <$> (traverse . traverse) transformTypeSizes cs + transformScalarSizes (Arrow as argName d argT retT) = + Arrow as argName d <$> transformTypeSizes argT <*> transformRetTypeSizes argset retT + where + argset = + fvVars (freeInType argT) + <> case argName of + Unnamed -> mempty + Named vn -> S.singleton vn + transformScalarSizes (TypeVar as uniq qn args) = + TypeVar as uniq qn <$> mapM onArg args + where + onArg (TypeArgDim dim) = TypeArgDim <$> onDim dim + onArg (TypeArgType ty) = TypeArgType <$> transformTypeSizes ty + transformScalarSizes ty = pure ty + + onDim (SizeExpr e) = SizeExpr <$> (replaceExp =<< transformExp e) + onDim (AnySize v) = pure $ AnySize v + +transformRetTypeSizes :: S.Set VName -> RetTypeBase Size as -> MonoM (RetTypeBase Size as) +transformRetTypeSizes argset (RetType dims ty) = do + ty' <- withArgs argset $ transformTypeSizes ty + rl <- parametrizing argset + let dims' = dims <> map snd rl + pure $ RetType dims' ty' + +transformTypeExp :: TypeExp Info VName -> MonoM (TypeExp Info VName) +transformTypeExp te@TEVar {} = pure te +transformTypeExp (TEParens te loc) = + TEParens <$> transformTypeExp te <*> pure loc +transformTypeExp (TETuple tes loc) = + TETuple <$> mapM transformTypeExp tes <*> pure loc +transformTypeExp (TERecord fs loc) = + TERecord <$> mapM (traverse transformTypeExp) fs <*> pure loc +transformTypeExp (TEArray size te loc) = + TEArray <$> transformSizeExp size <*> transformTypeExp te <*> pure loc + where + transformSizeExp (SizeExp e loc') = + SizeExp <$> (replaceExp =<< transformExp e) <*> pure loc' + transformSizeExp (SizeExpAny loc') = + pure $ SizeExpAny loc' +transformTypeExp (TEUnique te loc) = + TEUnique <$> transformTypeExp te <*> pure loc +transformTypeExp (TEApply te args loc) = + TEApply <$> transformTypeExp te <*> transformTypeArg args <*> pure loc + where + transformTypeArg (TypeArgExpSize size) = + TypeArgExpSize <$> transformSizeExp size + transformTypeArg (TypeArgExpType arg) = + TypeArgExpType <$> transformTypeExp arg + transformSizeExp (SizeExp e loc') = + SizeExp <$> (replaceExp =<< transformExp e) <*> pure loc' + transformSizeExp (SizeExpAny loc') = + pure $ SizeExpAny loc' +transformTypeExp (TEArrow aname ta tr loc) = do + tr' <- case aname of + Just vn -> do + let argset = S.singleton vn + ret <- withArgs argset $ transformTypeExp tr + dims <- parametrizing argset + if null dims + then pure ret + else pure $ TEDim (map snd dims) ret mempty + Nothing -> transformTypeExp tr + TEArrow aname <$> transformTypeExp ta <*> pure tr' <*> pure loc +transformTypeExp (TESum cs loc) = + TESum <$> traverse (traverse (traverse transformTypeExp)) cs <*> pure loc +transformTypeExp (TEDim dims te loc) = + TEDim dims <$> transformTypeExp te <*> pure loc + -- This carries out record replacements in the alias information of a type. -transformType :: TypeBase dim Aliasing -> MonoM (TypeBase dim Aliasing) +-- +-- It also transforms any size expressions. +transformType :: PatType -> MonoM PatType transformType t = do rrs <- asks envRecordReplacements let replace (AliasBound v) | Just d <- M.lookup v rrs = S.fromList $ map (AliasBound . fst) $ M.elems d replace x = S.singleton x + t' <- transformTypeSizes t -- As an attempt at an optimisation, only transform the aliases if -- they refer to a variable we have record-replaced. pure $ if any ((`M.member` rrs) . aliasVar) $ aliases t - then second (mconcat . map replace . S.toList) t - else t + then second (mconcat . map replace . S.toList) t' + else t' sizesForPat :: MonadFreshNames m => Pat -> m ([VName], Pat) sizesForPat pat = do @@ -267,7 +560,7 @@ sizesForPat pat = do onDim (AnySize _) = do v <- lift $ newVName "size" modify (v :) - pure $ NamedSize $ qualName v + pure $ sizeFromName (qualName v) mempty onDim d = pure d transformAppRes :: AppRes -> MonoM AppRes @@ -281,16 +574,19 @@ transformAppExp (Range e1 me incl loc) res = do incl' <- mapM transformExp incl pure $ AppExp (Range e1' me' incl' loc) (Info res) transformAppExp (Coerce e tp loc) res = - AppExp <$> (Coerce <$> transformExp e <*> pure tp <*> pure loc) <*> pure (Info res) -transformAppExp (LetPat sizes pat e1 e2 loc) res = do - (pat', rr) <- transformPat pat - AppExp - <$> ( LetPat sizes pat' - <$> transformExp e1 - <*> withRecordReplacements rr (transformExp e2) - <*> pure loc - ) - <*> pure (Info res) + AppExp <$> (Coerce <$> transformExp e <*> transformTypeExp tp <*> pure loc) <*> pure (Info res) +transformAppExp (LetPat sizes pat e body loc) res = do + e' <- transformExp e + let dimArgs = S.fromList (map sizeName sizes) + implicitDims <- withArgs dimArgs $ askIntros $ fvVars $ freeInPat pat + let dimArgs' = dimArgs <> implicitDims + letArgs = patNames pat + argset = dimArgs' `S.union` letArgs + (pat', rr) <- withArgs dimArgs' $ transformPat pat + params <- parametrizing dimArgs' + let sizes' = sizes <> map (`SizeBinder` mempty) (map snd params <> S.toList implicitDims) + body' <- withRecordReplacements rr $ withParams params $ scoping argset $ transformExp body + pure $ AppExp (LetPat sizes' pat' e' body' loc) (Info res) transformAppExp (LetFun fname (tparams, params, retdecl, Info ret, body) e loc) res | not $ null tparams = do -- Retrieve the lifted monomorphic function bindings that are produced, @@ -299,7 +595,7 @@ transformAppExp (LetFun fname (tparams, params, retdecl, Info ret, body) e loc) rr <- asks envRecordReplacements let funbind = PolyBinding rr (fname, tparams, params, ret, body, mempty, loc) pass $ do - (e', bs) <- listen $ extendEnv fname funbind $ transformExp e + (e', bs) <- listen $ extendEnv fname funbind $ scoping (S.singleton fname) $ transformExp e -- Do not remember this one for next time we monomorphise this -- function. modifyLifts $ filter ((/= fname) . fst . fst) @@ -308,7 +604,10 @@ transformAppExp (LetFun fname (tparams, params, retdecl, Info ret, body) e loc) | otherwise = do body' <- transformExp body AppExp - <$> (LetFun fname (tparams, params, retdecl, Info ret, body') <$> transformExp e <*> pure loc) + <$> ( LetFun fname (tparams, params, retdecl, Info ret, body') + <$> scoping (S.singleton fname) (transformExp e) + <*> pure loc + ) <*> pure (Info res) transformAppExp (If e1 e2 e3 loc) res = AppExp <$> (If <$> transformExp e1 <*> transformExp e2 <*> transformExp e3 <*> pure loc) <*> pure (Info res) @@ -319,20 +618,36 @@ transformAppExp (Apply fe args _) res = <*> pure res where onArg (Info (d, ext), e) = (d,ext,) <$> transformExp e -transformAppExp (DoLoop sparams pat e1 form e3 loc) res = do +transformAppExp (DoLoop sparams pat e1 form body loc) res = do e1' <- transformExp e1 - form' <- case form of - For ident e2 -> For ident <$> transformExp e2 - ForIn pat2 e2 -> ForIn pat2 <$> transformExp e2 - While e2 -> While <$> transformExp e2 - e3' <- transformExp e3 + + let dimArgs = S.fromList sparams + (pat', rr) <- withArgs dimArgs $ transformPat pat + params <- parametrizing dimArgs + let sparams' = sparams <> map snd params + mergeArgs = dimArgs `S.union` patNames pat + + (form', rr', formArgs) <- case form of + For ident e2 -> (,mempty,S.singleton $ identName ident) . For ident <$> transformExp e2 + ForIn pat2 e2 -> do + (pat2', rr') <- transformPat pat2 + (,rr',patNames pat2) . ForIn pat2' <$> transformExp e2 + While e2 -> + fmap ((,mempty,mempty) . While) $ + withRecordReplacements rr $ + withParams params $ + scoping mergeArgs $ + transformExp e2 + let argset = mergeArgs `S.union` formArgs + + body' <- withRecordReplacements (rr <> rr') $ withParams params $ scoping argset $ transformExp body -- Maybe monomorphisation introduced new arrays to the loop, and -- maybe they have AnySize sizes. This is not allowed. Invent some -- sizes for them. - (pat_sizes, pat') <- sizesForPat pat - pure $ AppExp (DoLoop (sparams ++ pat_sizes) pat' e1' form' e3' loc) (Info res) + (pat_sizes, pat'') <- sizesForPat pat' + pure $ AppExp (DoLoop (sparams' ++ pat_sizes) pat'' e1' form' body' loc) (Info res) transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) (AppRes ret ext) = do - fname' <- transformFName loc fname $ toStruct t + fname' <- transformFName loc fname =<< transformTypeSizes (toStruct t) e1' <- transformExp e1 e2' <- transformExp e2 if orderZero (typeOf e1') && orderZero (typeOf e2') @@ -377,18 +692,40 @@ transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) (AppRes ret ex Id x (Info $ fromStruct argtype) mempty ) transformAppExp (LetWith id1 id2 idxs e1 body loc) res = do + id1' <- transformIdent id1 + id2' <- transformIdent id2 idxs' <- mapM transformDimIndex idxs e1' <- transformExp e1 - body' <- transformExp body - pure $ AppExp (LetWith id1 id2 idxs' e1' body' loc) (Info res) + body' <- scoping (S.singleton $ identName id1') $ transformExp body + pure $ AppExp (LetWith id1' id2' idxs' e1' body' loc) (Info res) + where + transformIdent (Ident v t vloc) = + Ident v <$> traverse transformType t <*> pure vloc transformAppExp (Index e0 idxs loc) res = AppExp <$> (Index <$> transformExp e0 <*> mapM transformDimIndex idxs <*> pure loc) <*> pure (Info res) -transformAppExp (Match e cs loc) res = - AppExp - <$> (Match <$> transformExp e <*> mapM transformCase cs <*> pure loc) - <*> pure (Info res) +transformAppExp (Match e cs loc) res = do + implicitDims <- askIntros $ fvVars $ freeInType $ typeOf e + e' <- transformExp e + cs' <- mapM (transformCase implicitDims) cs + if S.null implicitDims + then pure $ AppExp (Match e' cs' loc) (Info res) + else do + tmpVar <- newNameFromString "matched_variable" + pure $ + AppExp + ( LetPat + (map (`SizeBinder` mempty) $ S.toList implicitDims) + (Id tmpVar (Info $ typeOf e') mempty) + e' + ( AppExp + (Match (Var (qualName tmpVar) (Info $ typeOf e') mempty) cs' loc) + (Info res) + ) + mempty + ) + (Info res) -- Monomorphization of expressions. transformExp :: Exp -> MonoM Exp @@ -416,8 +753,7 @@ transformExp (RecordLit fs loc) = loc transformExp (ArrayLit es t loc) = ArrayLit <$> mapM transformExp es <*> traverse transformType t <*> pure loc -transformExp (AppExp e res) = do - noticeDims $ appResType $ unInfo res +transformExp (AppExp e res) = transformAppExp e =<< transformAppRes (unInfo res) transformExp (Var fname (Info t) loc) = do maybe_fs <- lookupRecordReplacement $ qualLeaf fname @@ -440,16 +776,24 @@ transformExp (Negate e loc) = transformExp (Not e loc) = Not <$> transformExp e <*> pure loc transformExp (Lambda params e0 decl tp loc) = do - e0' <- transformExp e0 - pure $ Lambda params e0' decl tp loc + let patArgs = foldMap patNames params + dimArgs <- withArgs patArgs $ askIntros (foldMap (fvVars . freeInPat) params) + let argset = dimArgs `S.union` patArgs + (params', rrs) <- mapAndUnzipM transformPat params + paramed <- parametrizing argset + withRecordReplacements (mconcat rrs) $ + Lambda params' + <$> withParams paramed (scoping argset $ transformExp e0) + <*> pure decl + <*> traverse (traverse transformRetType) tp + <*> pure loc transformExp (OpSection qn t loc) = transformExp $ Var qn t loc transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc) = do let (Info (xp, xtype, xargext), Info (yp, ytype)) = arg - fname' <- transformFName loc fname $ toStruct t e' <- transformExp e desugarBinOpSection - fname' + fname (Just e') Nothing t @@ -459,10 +803,9 @@ transformExp (OpSectionLeft fname (Info t) e arg (Info rettype, Info retext) loc loc transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do let (Info (xp, xtype), Info (yp, ytype, yargext)) = arg - fname' <- transformFName loc fname $ toStruct t e' <- transformExp e desugarBinOpSection - fname' + fname Nothing (Just e') t @@ -470,22 +813,24 @@ transformExp (OpSectionRight fname (Info t) e arg (Info rettype) loc) = do (yp, ytype, yargext) (rettype, []) loc -transformExp (ProjectSection fields (Info t) loc) = - desugarProjectSection fields t loc +transformExp (ProjectSection fields (Info t) loc) = do + t' <- transformType t + desugarProjectSection fields t' loc transformExp (IndexSection idxs (Info t) loc) = do idxs' <- mapM transformDimIndex idxs desugarIndexSection idxs' t loc transformExp (Project n e tp loc) = do + tp' <- traverse transformType tp maybe_fs <- case e of Var qn _ _ -> lookupRecordReplacement (qualLeaf qn) _ -> pure Nothing case maybe_fs of Just m | Just (v, _) <- M.lookup n m -> - pure $ Var (qualName v) tp loc + pure $ Var (qualName v) tp' loc _ -> do e' <- transformExp e - pure $ Project n e' tp loc + pure $ Project n e' tp' loc transformExp (Update e1 idxs e2 loc) = Update <$> transformExp e1 @@ -497,19 +842,19 @@ transformExp (RecordUpdate e1 fs e2 t loc) = <$> transformExp e1 <*> pure fs <*> transformExp e2 - <*> pure t + <*> traverse transformType t <*> pure loc transformExp (Assert e1 e2 desc loc) = Assert <$> transformExp e1 <*> transformExp e2 <*> pure desc <*> pure loc transformExp (Constr name all_es t loc) = - Constr name <$> mapM transformExp all_es <*> pure t <*> pure loc + Constr name <$> mapM transformExp all_es <*> traverse transformType t <*> pure loc transformExp (Attr info e loc) = Attr info <$> transformExp e <*> pure loc -transformCase :: Case -> MonoM Case -transformCase (CasePat p e loc) = do +transformCase :: S.Set VName -> Case -> MonoM Case +transformCase implicitDims (CasePat p e loc) = do (p', rr) <- transformPat p - CasePat p' <$> withRecordReplacements rr (transformExp e) <*> pure loc + CasePat p' <$> withRecordReplacements rr (scoping (patNames p `S.union` implicitDims) $ transformExp e) <*> pure loc transformDimIndex :: DimIndexBase Info VName -> MonoM (DimIndexBase Info VName) transformDimIndex (DimFix e) = DimFix <$> transformExp e @@ -520,7 +865,7 @@ transformDimIndex (DimSlice me1 me2 me3) = -- Transform an operator section into a lambda. desugarBinOpSection :: - Exp -> + QualName VName -> Maybe Exp -> Maybe Exp -> PatType -> @@ -529,27 +874,28 @@ desugarBinOpSection :: (PatRetType, [VName]) -> SrcLoc -> MonoM Exp -desugarBinOpSection op e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (RetType dims rettype, retext) loc = do - (v1, wrap_left, e1, p1) <- makeVarParam e_left $ fromStruct xtype - (v2, wrap_right, e2, p2) <- makeVarParam e_right $ fromStruct ytype +desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (RetType dims rettype, retext) loc = do + t' <- transformTypeSizes t + op <- transformFName loc fname $ toStruct t' + (v1, wrap_left, e1, p1) <- makeVarParam e_left . fromStruct =<< transformTypeSizes xtype + (v2, wrap_right, e2, p2) <- makeVarParam e_right . fromStruct =<< transformTypeSizes ytype let apply_left = mkApply op [(Observe, xext, e1)] - (AppRes (Scalar $ Arrow mempty yp Observe ytype (RetType [] t)) []) - rettype' = - let onDim (NamedSize d) - | Named p <- xp, qualLeaf d == p = NamedSize $ qualName v1 - | Named p <- yp, qualLeaf d == p = NamedSize $ qualName v2 - onDim d = d - in first onDim rettype - body = - mkApply apply_left [(Observe, yext, e2)] (AppRes rettype' retext) + (AppRes (Scalar $ Arrow mempty yp Observe ytype (RetType [] t')) []) + onDim (SizeExpr (Var d typ _)) + | Named p <- xp, qualLeaf d == p = SizeExpr $ Var (qualName v1) typ loc + | Named p <- yp, qualLeaf d == p = SizeExpr $ Var (qualName v2) typ loc + onDim d = d + rettype' = first onDim rettype rettype'' = toStruct rettype' + body <- scoping (S.fromList [v1, v2]) $ mkApply apply_left [(Observe, yext, e2)] <$> transformAppRes (AppRes rettype' retext) + rettype''' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype'' pure $ wrap_left $ wrap_right $ - Lambda (p1 ++ p2) body Nothing (Info (mempty, RetType dims rettype'')) loc + Lambda (p1 ++ p2) body Nothing (Info (mempty, rettype''')) loc where patAndVar argtype = do x <- newNameFromString "x" @@ -597,23 +943,18 @@ desugarProjectSection _ t _ = error $ "desugarOpSection: not a function type: " desugarIndexSection :: [DimIndex] -> PatType -> SrcLoc -> MonoM Exp desugarIndexSection idxs (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do p <- newVName "index_i" - let body = AppExp (Index (Var (qualName p) (Info t1') loc) idxs loc) (Info (AppRes t2 [])) + t1' <- fromStruct <$> transformTypeSizes t1 + t2' <- transformType t2 + let body = AppExp (Index (Var (qualName p) (Info t1') loc) idxs loc) (Info (AppRes t2' [])) pure $ Lambda [Id p (Info (fromStruct t1')) mempty] body Nothing - (Info (mempty, RetType dims $ toStruct t2)) + (Info (mempty, RetType dims $ toStruct t2')) loc - where - t1' = fromStruct t1 desugarIndexSection _ t _ = error $ "desugarIndexSection: not a function type: " ++ prettyString t -noticeDims :: TypeBase Size as -> MonoM () -noticeDims = mapM_ notice . freeInType - where - notice v = void $ transformFName mempty (qualName v) i64 - -- Convert a collection of 'ValBind's to a nested sequence of let-bound, -- monomorphic functions with the given expression at the bottom. unfoldLetFuns :: [ValBind] -> Exp -> Exp @@ -636,7 +977,9 @@ transformPat (Id v (Info (Scalar (Record fs))) loc) = do loc, M.singleton v $ M.fromList $ zip (map fst fs') $ zip fs_ks fs_ts ) -transformPat (Id v t loc) = pure (Id v t loc, mempty) +transformPat (Id v t loc) = do + t' <- traverse transformType t + pure (Id v t' loc, mempty) transformPat (TuplePat pats loc) = do (pats', rrs) <- mapAndUnzipM transformPat pats pure (TuplePat pats' loc, mconcat rrs) @@ -672,28 +1015,49 @@ dimMapping :: Monoid a => TypeBase Size a -> TypeBase Size a -> + ExpReplacements -> + ExpReplacements -> DimInst -dimMapping t1 t2 = execState (matchDims f t1 t2) mempty +dimMapping t1 t2 r1 r2 = execState (matchDims onDims t1 t2) mempty where - f bound d1 (NamedSize d2) - | qualLeaf d2 `elem` bound = pure d1 - f _ (NamedSize d1) d2 = do - modify $ M.insert (qualLeaf d1) d2 - pure $ NamedSize d1 - f _ d _ = pure d - -inferSizeArgs :: [TypeParam] -> StructType -> StructType -> [Exp] -inferSizeArgs tparams bind_t t = - mapMaybe (tparamArg (dimMapping bind_t t)) tparams + revMap = map (\(k, v) -> (v, k)) + named1 = revMap r1 + named2 = revMap r2 + + onDims bound (SizeExpr e1) (SizeExpr e2) = do + onExps bound e1 e2 + pure $ SizeExpr e1 + onDims _ d _ = pure d + + onExps bound (Var v _ _) e = do + unless (any (`elem` bound) $ freeVarsInExp e) $ + modify $ + M.insert (qualLeaf v) $ + SizeExpr e + case lookup (qualLeaf v) named1 of + Just rexp -> onExps bound (unReplaced rexp) e + Nothing -> pure () + onExps bound e (Var v _ _) + | Just rexp <- lookup (qualLeaf v) named2 = onExps bound e (unReplaced rexp) + onExps bound e1 e2 + | Just es <- similarExps e1 e2 = + mapM_ (uncurry $ onExps bound) es + onExps _ _ _ = pure mempty + + freeVarsInExp = M.keys . unFV . freeInExp + +inferSizeArgs :: [TypeParam] -> StructType -> ExpReplacements -> StructType -> MonoM [Exp] +inferSizeArgs tparams bind_t bind_r t = do + r <- get + let dinst = dimMapping bind_t t bind_r r + mapM (tparamArg dinst) tparams where tparamArg dinst tp = case M.lookup (typeParamName tp) dinst of - Just (NamedSize d) -> - Just $ Var d (Info i64) mempty - Just (ConstSize x) -> - Just $ Literal (SignedValue $ Int64Value $ fromIntegral x) mempty + Just (SizeExpr e) -> + replaceExp e _ -> - Just $ Literal (SignedValue $ Int64Value 0) mempty + pure $ Literal (SignedValue $ Int64Value 0) mempty -- Monomorphising higher-order functions can result in function types -- where the same named parameter occurs in multiple spots. When @@ -715,6 +1079,99 @@ noNamedParams = f Sum $ fmap (map f) cs f' t = t +transformRetType :: StructRetType -> MonoM StructRetType +transformRetType (RetType ext t) = RetType ext <$> transformTypeSizes t + +-- | arrowArg takes a return type and returns it +-- with the existentials bound moved at the right of arrows. +-- It also gives the new set of parameters to consider. +arrowArg :: + S.Set VName -> -- scope + S.Set VName -> -- set of argument + [VName] -> -- size parameters + RetTypeBase Size as -> + (RetTypeBase Size as, S.Set VName) +arrowArg scope argset args_params rety = + let (rety', (funArgs, _)) = runWriter (arrowArgRetType (scope, mempty) argset rety) + new_params = funArgs `S.union` S.fromList args_params + in (arrowCleanRetType new_params rety', new_params) + where + -- \| takes a type (or return type) and returns it + -- with the existentials bound moved at the right of arrows. + -- It also gives (through writer monad) size variables used in arrow arguments + -- and variables that are constructively used. + -- The returned type should be cleanned, as too many existentials are introduced. + arrowArgRetType :: + (S.Set VName, [VName]) -> + S.Set VName -> + RetTypeBase Size as' -> + Writer (S.Set VName, S.Set VName) (RetTypeBase Size as') + arrowArgRetType (scope', dimsToPush) argset' (RetType dims ty) = pass $ do + let dims' = dims <> dimsToPush + (ty', (_, canExt)) <- listen $ arrowArgType (argset' `S.union` scope', dims') ty + pure (RetType (filter (`S.member` canExt) dims') ty', first (`S.difference` canExt)) + + arrowArgScalar env (Record fs) = + Record <$> traverse (arrowArgType env) fs + arrowArgScalar env (Sum cs) = + Sum <$> (traverse . traverse) (arrowArgType env) cs + arrowArgScalar (scope', dimsToPush) (Arrow as argName d argT retT) = + pass $ do + let intros = S.filter notIntrisic argset' `S.difference` scope' + retT' <- arrowArgRetType (scope', filter (`S.notMember` intros) dimsToPush) fullArgset retT + pure (Arrow as argName d argT retT', bimap (intros `S.union`) (const mempty)) + where + notIntrisic vn = baseTag vn > maxIntrinsicTag + argset' = fvVars $ freeInType argT + fullArgset = + argset' + <> case argName of + Unnamed -> mempty + Named vn -> S.singleton vn + arrowArgScalar env (TypeVar as uniq qn args) = + TypeVar as uniq qn <$> mapM arrowArgArg args + where + arrowArgArg (TypeArgDim dim) = TypeArgDim <$> arrowArgSize dim + arrowArgArg (TypeArgType ty) = TypeArgType <$> arrowArgType env ty + arrowArgScalar _ ty = pure ty + + arrowArgType :: + (S.Set VName, [VName]) -> + TypeBase Size as' -> + Writer (S.Set VName, S.Set VName) (TypeBase Size as') + arrowArgType env (Array as u shape scalar) = + Array as u <$> traverse arrowArgSize shape <*> arrowArgScalar env scalar + arrowArgType env (Scalar ty) = + Scalar <$> arrowArgScalar env ty + + arrowArgSize s@(SizeExpr (Var qn _ _)) = writer (s, (mempty, S.singleton $ qualLeaf qn)) + arrowArgSize s = pure s + + -- \| arrowClean cleans the mess in the type + arrowCleanRetType :: S.Set VName -> RetTypeBase Size as -> RetTypeBase Size as + arrowCleanRetType paramed (RetType dims ty) = + RetType (nubOrd $ filter (`S.notMember` paramed) dims) (arrowCleanType (paramed `S.union` S.fromList dims) ty) + + arrowCleanScalar :: S.Set VName -> ScalarTypeBase Size as -> ScalarTypeBase Size as + arrowCleanScalar paramed (Record fs) = + Record $ M.map (arrowCleanType paramed) fs + arrowCleanScalar paramed (Sum cs) = + Sum $ (M.map . map) (arrowCleanType paramed) cs + arrowCleanScalar paramed (Arrow as argName d argT retT) = + Arrow as argName d argT (arrowCleanRetType paramed retT) + arrowCleanScalar paramed (TypeVar as uniq qn args) = + TypeVar as uniq qn $ map arrowCleanArg args + where + arrowCleanArg (TypeArgDim dim) = TypeArgDim dim + arrowCleanArg (TypeArgType ty) = TypeArgType $ arrowCleanType paramed ty + arrowCleanScalar _ ty = ty + + arrowCleanType :: S.Set VName -> TypeBase Size as -> TypeBase Size as + arrowCleanType paramed (Array as u shape scalar) = + Array as u shape $ arrowCleanScalar paramed scalar + arrowCleanType paramed (Scalar ty) = + Scalar $ arrowCleanScalar paramed ty + -- Monomorphise a polymorphic function at the types given in the instance -- list. Monomorphises the body of the function as well. Returns the fresh name -- of the generated monomorphic function and its 'ValBind' representation. @@ -723,26 +1180,51 @@ monomorphiseBinding :: PolyBinding -> MonoType -> MonoM (VName, InferSizeArgs, ValBind) -monomorphiseBinding entry (PolyBinding rr (name, tparams, params, rettype, body, attrs, loc)) inst_t = - replaceRecordReplacements rr $ do +monomorphiseBinding entry (PolyBinding rr (name, tparams, params, rettype, body, attrs, loc)) inst_t = do + letFun <- asks $ S.member name . envScope + let paramGetClean argset = + if letFun + then parametrizing argset + else do + ret <- get + put mempty + pure ret + replaceRecordReplacements rr $ (if letFun then id else isolateNormalisation) $ do let bind_t = funType params rettype - (substs, t_shape_params) <- typeSubstsM loc (noSizes bind_t) $ noNamedParams inst_t - let substs' = M.map (Subst []) substs - rettype' = applySubst (`M.lookup` substs') rettype + (substs, t_shape_params) <- + typeSubstsM loc (noSizes bind_t) $ noNamedParams inst_t + let shape_names = S.fromList $ map typeParamName $ shape_params ++ t_shape_params + substs' = M.map (Subst []) substs substPatType = substTypesAny (fmap (fmap (second (const mempty))) . (`M.lookup` substs')) params' = map (substPat entry substPatType) params + (params'', rrs) <- withArgs shape_names $ mapAndUnzipM transformPat params' + exp_naming <- paramGetClean shape_names + + let args = foldMap patNames params + arg_params = map snd exp_naming + + rettype' <- withParams exp_naming (withArgs (args <> shape_names) $ hardTransformRetType rettype) + extNaming <- paramGetClean (args <> shape_names) + scope <- S.union shape_names <$> askScope + let (rettype'', new_params) = arrowArg scope args arg_params rettype' + rettype''' = applySubst (`M.lookup` substs') rettype'' bind_t' = substTypesAny (`M.lookup` substs') bind_t (shape_params_explicit, shape_params_implicit) = partition ((`S.member` mustBeExplicitInBinding bind_t') . typeParamName) $ - shape_params ++ t_shape_params - - (params'', rrs) <- mapAndUnzipM transformPat params' - - mapM_ noticeDims $ retType rettype : map patternStructType params'' + shape_params ++ t_shape_params ++ map (`TypeParamDim` mempty) (S.toList new_params) + exp_naming' = filter ((`S.member` new_params) . snd) (extNaming <> exp_naming) + bind_t'' = funType params'' rettype''' + bind_r = exp_naming <> extNaming body' <- updateExpTypes (`M.lookup` substs') body - body'' <- withRecordReplacements (mconcat rrs) $ transformExp body' + body'' <- withRecordReplacements (mconcat rrs) $ withParams exp_naming' $ withArgs (shape_names <> args) $ transformExp body' + scope' <- S.union (shape_names <> args) <$> askScope + body''' <- + if letFun + then unscoping (shape_names <> args) body'' + else expReplace exp_naming' <$> (calculateDims body'' . canCalculate scope' =<< get) + seen_before <- elem name . map (fst . fst) <$> getLifts name' <- if null tparams && not entry && not seen_before @@ -751,28 +1233,34 @@ monomorphiseBinding entry (PolyBinding rr (name, tparams, params, rettype, body, pure ( name', - inferSizeArgs shape_params_explicit bind_t', + inferSizeArgs shape_params_explicit bind_t'' bind_r, if entry then toValBinding name' (shape_params_explicit ++ shape_params_implicit) params'' - rettype' - body'' + rettype''' + (entryAssert exp_naming body''') else toValBinding name' shape_params_implicit (map shapeParam shape_params_explicit ++ params'') - rettype' - body'' + rettype''' + body''' ) where shape_params = filter (not . isTypeParam) tparams updateExpTypes substs = astMap (mapper substs) + hardTransformRetType (RetType _ ty) = do + ty' <- transformTypeSizes ty + unbounded <- askIntros $ fvVars $ freeInType ty' + let dims' = S.toList unbounded + pure $ RetType dims' ty' + mapper substs = ASTMapper { mapOnExp = updateExpTypes substs, @@ -852,9 +1340,9 @@ typeSubstsM loc orig_t1 orig_t2 = d <- lift $ lift $ newVName "d" tell [TypeParamDim d loc] put (ts, M.insert i d sizes) - pure $ NamedSize $ qualName d + pure $ sizeFromName (qualName d) mempty Just d -> - pure $ NamedSize $ qualName d + pure $ sizeFromName (qualName d) mempty onDim (MonoAnon v) = pure $ AnySize $ Just v -- Perform a given substitution on the types in a pattern. @@ -942,7 +1430,6 @@ transformValBind valbind = do transformTypeBind :: TypeBind -> MonoM Env transformTypeBind (TypeBind name l tparams _ (Info (RetType dims t)) _ _) = do subs <- asks $ M.map substFromAbbr . envTypeBindings - noticeDims t let tbinding = TypeAbbr l tparams $ RetType dims $ applySubst (`M.lookup` subs) t pure mempty {envTypeBindings = M.singleton name tbinding} diff --git a/src/Futhark/Internalise/TypesValues.hs b/src/Futhark/Internalise/TypesValues.hs index b0ba4e08f2..7fb652fe02 100644 --- a/src/Futhark/Internalise/TypesValues.hs +++ b/src/Futhark/Internalise/TypesValues.hs @@ -127,8 +127,9 @@ internaliseDim :: internaliseDim exts d = case d of E.AnySize _ -> Ext <$> newId - E.ConstSize n -> pure $ Free $ intConst I.Int64 $ toInteger n - E.NamedSize name -> pure $ namedDim name + E.SizeExpr (E.IntLit n _ _) -> pure $ Free $ intConst I.Int64 n + E.SizeExpr (E.Var name _ _) -> pure $ namedDim name + E.SizeExpr e -> error $ "Unexpected size expression: " ++ prettyString e where namedDim (E.QualName _ name) | Just x <- name `M.lookup` exts = I.Ext x @@ -152,7 +153,7 @@ internaliseTypeM exts orig_t = | null ets -> pure [I.Prim I.Unit] | otherwise -> concat <$> mapM (internaliseTypeM exts . snd) (E.sortFields ets) - E.Scalar (E.TypeVar _ u tn [E.TypeArgType arr_t _]) + E.Scalar (E.TypeVar _ u tn [E.TypeArgType arr_t]) | baseTag (E.qualLeaf tn) <= E.maxIntrinsicTag, baseString (E.qualLeaf tn) == "acc" -> do ts <- map (fromDecl . onAccType) <$> internaliseTypeM exts arr_t diff --git a/src/Language/Futhark.hs b/src/Language/Futhark.hs index 0f5b3457fa..498a137c6d 100644 --- a/src/Language/Futhark.hs +++ b/src/Language/Futhark.hs @@ -4,26 +4,6 @@ module Language.Futhark module Language.Futhark.Prop, module Language.Futhark.FreeVars, module Language.Futhark.Pretty, - Ident, - DimIndex, - Slice, - AppExp, - Exp, - Pat, - ModExp, - ModParam, - SigExp, - ModBind, - SigBind, - ValBind, - Dec, - Spec, - Prog, - TypeBind, - StructTypeArg, - ScalarType, - TypeParam, - Case, ) where @@ -31,63 +11,3 @@ import Language.Futhark.FreeVars import Language.Futhark.Pretty import Language.Futhark.Prop import Language.Futhark.Syntax - --- | An identifier with type- and aliasing information. -type Ident = IdentBase Info VName - --- | An index with type information. -type DimIndex = DimIndexBase Info VName - --- | A slice with type information. -type Slice = SliceBase Info VName - --- | An expression with type information. -type Exp = ExpBase Info VName - --- | An application expression with type information. -type AppExp = AppExpBase Info VName - --- | A pattern with type information. -type Pat = PatBase Info VName - --- | An constant declaration with type information. -type ValBind = ValBindBase Info VName - --- | A type binding with type information. -type TypeBind = TypeBindBase Info VName - --- | A type-checked module binding. -type ModBind = ModBindBase Info VName - --- | A type-checked module type binding. -type SigBind = SigBindBase Info VName - --- | A type-checked module expression. -type ModExp = ModExpBase Info VName - --- | A type-checked module parameter. -type ModParam = ModParamBase Info VName - --- | A type-checked module type expression. -type SigExp = SigExpBase Info VName - --- | A type-checked declaration. -type Dec = DecBase Info VName - --- | A type-checked specification. -type Spec = SpecBase Info VName - --- | An Futhark program with type information. -type Prog = ProgBase Info VName - --- | A known type arg with shape annotations. -type StructTypeArg = TypeArg Size - --- | A type-checked type parameter. -type TypeParam = TypeParamBase VName - --- | A known scalar type with no shape annotations. -type ScalarType = ScalarTypeBase () - --- | A type-checked case (of a match expression). -type Case = CaseBase Info VName diff --git a/src/Language/Futhark/FreeVars.hs b/src/Language/Futhark/FreeVars.hs index 65cc3a46cd..5d72397ae1 100644 --- a/src/Language/Futhark/FreeVars.hs +++ b/src/Language/Futhark/FreeVars.hs @@ -6,6 +6,7 @@ module Language.Futhark.FreeVars freeInType, freeWithout, FV (..), + fvVars, ) where @@ -18,6 +19,10 @@ import Language.Futhark.Syntax newtype FV = FV {unFV :: M.Map VName StructType} deriving (Show) +-- | The set of names in an 'FV'. +fvVars :: FV -> S.Set VName +fvVars = M.keysSet . unFV + instance Semigroup FV where FV x <> FV y = FV $ M.unionWith max x y @@ -33,13 +38,6 @@ freeWithout (FV x) y = FV $ M.filterWithKey keep x ident :: IdentBase Info VName -> FV ident v = FV $ M.singleton (identName v) (toStruct $ unInfo (identType v)) -size :: VName -> FV -size v = FV $ M.singleton v $ Scalar $ Prim $ Signed Int64 - --- | A 'FV' with these names, considered to be sizes. -sizes :: S.Set VName -> FV -sizes = foldMap size - -- | Compute the set of free variables of an expression. freeInExp :: ExpBase Info VName -> FV freeInExp expr = case expr of @@ -56,20 +54,20 @@ freeInExp expr = case expr of freeInExpField (RecordFieldExplicit _ e _) = freeInExp e freeInExpField (RecordFieldImplicit vn t _) = ident $ Ident vn t mempty ArrayLit es t _ -> - foldMap freeInExp es <> sizes (freeInType $ unInfo t) + foldMap freeInExp es <> freeInType (unInfo t) AppExp (Range e me incl _) _ -> freeInExp e <> foldMap freeInExp me <> foldMap freeInExp incl Var qn (Info t) _ -> FV $ M.singleton (qualLeaf qn) $ toStruct t Ascript e _ _ -> freeInExp e AppExp (Coerce e _ _) (Info ar) -> - freeInExp e <> sizes (freeInType (appResType ar)) + freeInExp e <> freeInType (appResType ar) AppExp (LetPat let_sizes pat e1 e2 _) _ -> freeInExp e1 - <> ( (sizes (freeInPat pat) <> freeInExp e2) + <> ( (freeInPat pat <> freeInExp e2) `freeWithout` (patNames pat <> S.fromList (map sizeName let_sizes)) ) AppExp (LetFun vn (tparams, pats, _, _, e1) e2 _) _ -> - ( (freeInExp e1 <> sizes (foldMap freeInPat pats)) + ( (freeInExp e1 <> foldMap freeInPat pats) `freeWithout` ( foldMap patNames pats <> S.fromList (map typeParamName tparams) ) @@ -80,7 +78,7 @@ freeInExp expr = case expr of Negate e _ -> freeInExp e Not e _ -> freeInExp e Lambda pats e0 _ (Info (_, RetType dims t)) _ -> - (sizes (foldMap freeInPat pats) <> freeInExp e0 <> sizes (freeInType t)) + (foldMap freeInPat pats <> freeInExp e0 <> freeInType t) `freeWithout` (foldMap patNames pats <> S.fromList dims) OpSection {} -> mempty OpSectionLeft _ _ e _ _ _ -> freeInExp e @@ -116,7 +114,7 @@ freeInExp expr = case expr of AppExp (Match e cs _) _ -> freeInExp e <> foldMap caseFV cs where caseFV (CasePat p eCase _) = - (sizes (freeInPat p) <> freeInExp eCase) + (freeInPat p <> freeInExp eCase) `freeWithout` patNames p freeInDimIndex :: DimIndexBase Info VName -> FV @@ -125,7 +123,7 @@ freeInDimIndex (DimSlice me1 me2 me3) = foldMap (foldMap freeInExp) [me1, me2, me3] -- | Free variables in pattern (including types of the bound identifiers). -freeInPat :: PatBase Info VName -> S.Set VName +freeInPat :: PatBase Info VName -> FV freeInPat (TuplePat ps _) = foldMap freeInPat ps freeInPat (RecordPat fs _) = foldMap (freeInPat . snd) fs freeInPat (PatParens p _) = freeInPat p @@ -137,7 +135,7 @@ freeInPat (PatConstr _ _ ps _) = foldMap freeInPat ps freeInPat (PatAttr _ p _) = freeInPat p -- | Free variables in the type (meaning those that are used in size expression). -freeInType :: TypeBase Size as -> S.Set VName +freeInType :: TypeBase Size as -> FV freeInType t = case t of Array _ _ s a -> @@ -149,15 +147,18 @@ freeInType t = Scalar (Sum cs) -> foldMap (foldMap freeInType) cs Scalar (Arrow _ v _ t1 (RetType dims t2)) -> - S.filter (notV v) $ S.filter (`notElem` dims) $ freeInType t1 <> freeInType t2 + FV $ + M.filterWithKey (\k _ -> notV v k && notElem k dims) $ + unFV $ + freeInType t1 <> freeInType t2 Scalar (TypeVar _ _ _ targs) -> foldMap typeArgDims targs where - typeArgDims (TypeArgDim d _) = onSize d - typeArgDims (TypeArgType at _) = freeInType at + typeArgDims (TypeArgDim d) = onSize d + typeArgDims (TypeArgType at) = freeInType at notV Unnamed = const True notV (Named v) = (/= v) - onSize (NamedSize qn) = S.singleton $ qualLeaf qn + onSize (SizeExpr e) = freeInExp e onSize _ = mempty diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index e64171e315..1ae3eaa171 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -42,10 +42,13 @@ import Data.List isPrefixOf, transpose, ) +import Data.List qualified as L import Data.List.NonEmpty qualified as NE import Data.Map qualified as M import Data.Maybe import Data.Monoid hiding (Sum) +import Data.Ord +import Data.Set qualified as S import Data.Text qualified as T import Futhark.Data qualified as V import Futhark.Util (chunk, maybeHead, splitFromEnd) @@ -146,37 +149,65 @@ extSizeEnv :: EvalM Env extSizeEnv = i64Env <$> getSizes valueStructType :: ValueType -> StructType -valueStructType = first (ConstSize . fromIntegral) - -resolveTypeParams :: [VName] -> StructType -> StructType -> Env -resolveTypeParams names = match +valueStructType = first $ flip sizeFromInteger mempty . toInteger + +resolveTypeParams :: + [VName] -> + StructType -> + StructType -> + ([(VName, ([VName], StructType))], [(VName, Exp)]) +resolveTypeParams names orig_t1 orig_t2 = + execState (match mempty orig_t1 orig_t2) mempty where - match (Scalar (TypeVar _ _ tn _)) t - | qualLeaf tn `elem` names = - typeEnv $ M.singleton (qualLeaf tn) t - match (Scalar (Record poly_fields)) (Scalar (Record fields)) = - mconcat $ - M.elems $ - M.intersectionWith match poly_fields fields - match (Scalar (Sum poly_fields)) (Scalar (Sum fields)) = - mconcat $ - map mconcat $ - M.elems $ - M.intersectionWith (zipWith match) poly_fields fields + addType v t = modify $ first $ L.insertBy (comparing fst) (v, t) + addDim v e = modify $ second $ L.insertBy (comparing fst) (v, e) + + match bound (Scalar (TypeVar _ _ tn _)) t + | qualLeaf tn `elem` names = addType (qualLeaf tn) (bound, t) + match bound (Scalar (Record poly_fields)) (Scalar (Record fields)) = + sequence_ . M.elems $ + M.intersectionWith (match bound) poly_fields fields + match bound (Scalar (Sum poly_fields)) (Scalar (Sum fields)) = + sequence_ . mconcat . M.elems $ + M.intersectionWith (zipWith $ match bound) poly_fields fields match - (Scalar (Arrow _ _ _ poly_t1 (RetType _ poly_t2))) - (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = - match poly_t1 t1 <> match poly_t2 t2 - match poly_t t + bound + (Scalar (Arrow _ p1 _ poly_t1 (RetType dims1 poly_t2))) + (Scalar (Arrow _ p2 _ t1 (RetType dims2 t2))) = do + let bound' = mapMaybe paramName [p1, p2] <> dims1 <> dims2 <> bound + match bound' poly_t1 t1 + match bound' poly_t2 t2 + match bound poly_t t | d1 : _ <- shapeDims (arrayShape poly_t), - d2 : _ <- shapeDims (arrayShape t) = - matchDims d1 d2 <> match (stripArray 1 poly_t) (stripArray 1 t) - match _ _ = mempty - - matchDims (NamedSize (QualName _ d1)) (ConstSize d2) - | d1 `elem` names = - i64Env $ M.singleton d1 $ fromIntegral d2 - matchDims _ _ = mempty + d2 : _ <- shapeDims (arrayShape t) = do + matchDims bound d1 d2 + match bound (stripArray 1 poly_t) (stripArray 1 t) + match bound t1 t2 + | Just t1' <- isAccType t1, + Just t2' <- isAccType t2 = + match bound t1' t2' + match _ _ _ = pure mempty + + matchDims bound (SizeExpr e1) (SizeExpr e2) = matchExps bound e1 e2 + matchDims _ _ _ = pure mempty + + matchExps bound (Var (QualName _ d1) _ _) e + | d1 `elem` names, + not $ any (`elem` bound) $ freeVarsInExp e = + addDim d1 e + matchExps bound e1 e2 + | Just es <- similarExps e1 e2 = + mapM_ (uncurry $ matchExps bound) es + matchExps _ _ _ = pure mempty + +evalResolved :: + Eval -> + ([(VName, ([VName], StructType))], [(VName, Exp)]) -> + EvalM Env +evalResolved eval' (ts, ds) = do + ts' <- mapM (traverse $ \(bound, t) -> evalType eval' (S.fromList bound) t) ts + ds' <- mapM (traverse $ fmap asInt64 . eval') ds + pure $ typeEnv (M.fromList ts') <> i64Env (M.fromList ds') resolveExistentials :: [VName] -> StructType -> ValueShape -> M.Map VName Int64 resolveExistentials names = match @@ -195,7 +226,7 @@ resolveExistentials names = match matchDims d1 d2 <> match (stripArray 1 poly_t) rowshape match _ _ = mempty - matchDims (NamedSize (QualName _ d1)) d2 + matchDims (SizeExpr (Var (QualName _ d1) _ _)) d2 | d1 `elem` names = M.singleton d1 d2 matchDims _ _ = mempty @@ -261,11 +292,16 @@ lookupVar = lookupInEnv envTerm lookupType :: QualName VName -> Env -> Maybe T.TypeBinding lookupType = lookupInEnv envType +-- | An expression evaluator that embeds an environment. +type Eval = Exp -> EvalM Value + -- | A TermValue with a 'Nothing' type annotation is an intrinsic. data TermBinding = TermValue (Maybe T.BoundV) Value - | -- | A polymorphic value that must be instantiated. - TermPoly (Maybe T.BoundV) (StructType -> EvalM Value) + | -- | A polymorphic value that must be instantiated. The + -- 'StructType' provided is un-evaluated, but parts of it can be + -- evaluated using the provided 'Eval' function. + TermPoly (Maybe T.BoundV) (StructType -> Eval -> EvalM Value) | TermModule Module data Module @@ -545,6 +581,9 @@ evalIndex loc env is arr = do <> "." maybe oob pure $ indexArray is arr +freeVarsInExp :: Exp -> [VName] +freeVarsInExp = M.keys . unFV . freeInExp + -- | Expand type based on information that was not available at -- type-checking time (the structure of abstract types). expandType :: Env -> StructType -> StructType @@ -556,85 +595,113 @@ expandType env t@(Array _ u shape _) = let et = stripArray (shapeRank shape) t et' = expandType env et in arrayOf u shape et' -expandType env t@(Scalar (TypeVar () _ tn args)) = +expandType env (Scalar (TypeVar () u tn args)) = case lookupType tn env of - Just (T.TypeAbbr _ ps (RetType _ t')) -> + Just (T.TypeAbbr _ ps (RetType ext t')) -> let (substs, types) = mconcat $ zipWith matchPtoA ps args - onDim (NamedSize v) = fromMaybe (NamedSize v) $ M.lookup (qualLeaf v) substs + onDim (SizeExpr (Var v _ _)) + | Just e <- M.lookup (qualLeaf v) substs = + e + -- The next case can occur when a type with existential size + -- has been hidden by a module ascription, + -- e.g. tests/modules/sizeparams4.fut. + onDim (SizeExpr e) + | any (`elem` ext) $ freeVarsInExp e = + AnySize Nothing onDim d = d in if null ps then first onDim t' else expandType (Env mempty types <> env) $ first onDim t' - Nothing -> t + Nothing -> + -- This case only happens for built-in abstract types, + -- e.g. accumulators. + Scalar (TypeVar () u tn $ map expandArg args) where - matchPtoA (TypeParamDim p _) (TypeArgDim (NamedSize qv) _) = - (M.singleton p $ NamedSize qv, mempty) - matchPtoA (TypeParamDim p _) (TypeArgDim (ConstSize k) _) = - (M.singleton p $ ConstSize k, mempty) - matchPtoA (TypeParamType l p _) (TypeArgType t' _) = + matchPtoA (TypeParamDim p _) (TypeArgDim (SizeExpr e)) = + (M.singleton p $ SizeExpr e, mempty) + matchPtoA (TypeParamType l p _) (TypeArgType t') = let t'' = expandType env t' in (mempty, M.singleton p $ T.TypeAbbr l [] $ RetType [] t'') matchPtoA _ _ = mempty + expandArg (TypeArgDim s) = TypeArgDim s + expandArg (TypeArgType t) = TypeArgType $ expandType env t expandType env (Scalar (Sum cs)) = Scalar $ Sum $ (fmap . fmap) (expandType env) cs --- | First expand type abbreviations, then evaluate all possible --- sizes. -evalType :: Env -> StructType -> EvalM StructType -evalType outer_env t = do +evalWithExts :: Env -> EvalM Eval +evalWithExts env = do size_env <- extSizeEnv - let env = size_env <> outer_env - evalDim (NamedSize qn) - | Just (TermValue _ (ValuePrim (SignedValue (Int64Value x)))) <- - lookupVar qn env = - ConstSize $ fromIntegral x - evalDim d = d - pure $ first evalDim $ expandType env t + pure $ eval $ size_env <> env + +-- | Evaluate all possible sizes, except those that contain free +-- variables in the set of names. +evalType :: Eval -> S.Set VName -> StructType -> EvalM StructType +evalType eval' outer_bound t = do + let evalDim bound _ (SizeExpr e) + | canBeEvaluated bound e = do + x <- asInteger <$> eval' e + pure $ SizeExpr $ IntLit x (Info (Scalar (Prim (Signed Int64)))) mempty + evalDim _ _ d = pure d + traverseDims evalDim t + where + canBeEvaluated bound e = + let free = freeVarsInExp e + in not $ any (`S.member` bound) free || any (`S.member` outer_bound) free evalTermVar :: Env -> QualName VName -> StructType -> EvalM Value evalTermVar env qv t = case lookupVar qv env of - Just (TermPoly _ v) -> v =<< evalType env t + Just (TermPoly _ v) -> v (expandType env t) =<< evalWithExts env Just (TermValue _ v) -> pure v - _ -> error $ "\"" <> prettyString qv <> "\" is not bound to a value." + _ -> do + ss <- map (locText . srclocOf) <$> stacktrace + error $ + prettyString qv + <> " is not bound to a value.\n" + <> T.unpack (prettyStacktrace 0 ss) typeValueShape :: Env -> StructType -> EvalM ValueShape typeValueShape env t = do - t' <- evalType env t + eval' <- evalWithExts env + t' <- evalType eval' mempty $ expandType env t case traverse dim $ typeShape t' of Nothing -> error $ "typeValueShape: failed to fully evaluate type " <> prettyString t' Just shape -> pure shape where - dim (ConstSize x) = Just $ fromIntegral x + dim (SizeExpr (IntLit x _ _)) = Just $ fromIntegral x dim _ = Nothing +-- Sometimes type instantiation is not quite enough - then we connect +-- up the missing sizes here. In particular used for eta-expanded +-- entry points. +linkMissingSizes :: [VName] -> Pat -> Value -> Env -> Env +linkMissingSizes [] _ _ env = env +linkMissingSizes missing_sizes p v env = + env <> i64Env (resolveExistentials missing_sizes p_t (valueShape v)) + where + p_t = expandType env $ patternStructType p + evalFunction :: Env -> [VName] -> [Pat] -> Exp -> StructType -> EvalM Value -- We treat zero-parameter lambdas as simply an expression to -- evaluate immediately. Note that this is *not* the same as a lambda -- that takes an empty tuple '()' as argument! Zero-parameter lambdas -- can never occur in a well-formed Futhark program, but they are -- convenient in the interpreter. -evalFunction env _ [] body rettype = +evalFunction env missing_sizes [] body rettype = -- Eta-expand the rest to make any sizes visible. etaExpand [] env rettype where - etaExpand vs env' (Scalar (Arrow _ _ _ pt (RetType _ rt))) = + etaExpand vs env' (Scalar (Arrow _ _ _ p_t (RetType _ rt))) = do pure . ValueFun $ \v -> do - env'' <- matchPat env' (Wildcard (Info $ fromStruct pt) noLoc) v + let p = Wildcard (Info $ fromStruct p_t) noLoc + env'' <- linkMissingSizes missing_sizes p v <$> matchPat env' p v etaExpand (v : vs) env'' rt etaExpand vs env' _ = do f <- localExts $ eval env' body foldM (apply noLoc mempty) f $ reverse vs evalFunction env missing_sizes (p : ps) body rettype = pure . ValueFun $ \v -> do - env' <- matchPat env p v - -- Fix up the last sizes, if any. - let p_t = expandType env $ patternStructType p - env'' - | null missing_sizes = - env' - | otherwise = - env' <> i64Env (resolveExistentials missing_sizes p_t (valueShape v)) - evalFunction env'' missing_sizes ps body rettype + env' <- linkMissingSizes missing_sizes p v <$> matchPat env p v + evalFunction env' missing_sizes ps body rettype evalFunctionBinding :: Env -> @@ -656,19 +723,20 @@ evalFunctionBinding env tparams ps ret fbody = do fmap (TermValue (Just $ T.BoundV [] ftype)) . returned env (retType ret) retext =<< evalFunction env [] ps fbody (retType ret) - else pure . TermPoly (Just $ T.BoundV [] ftype) $ \ftype' -> do - let tparam_names = map typeParamName tparams - env' = resolveTypeParams tparam_names ftype ftype' <> env - - -- In some cases (abstract lifted types) there may be - -- missing sizes that were not fixed by the type - -- instantiation. These will have to be set by looking - -- at the actual function arguments. - missing_sizes = - filter (`M.notMember` envTerm env') $ - map typeParamName (filter isSizeParam tparams) - returned env (retType ret) retext - =<< evalFunction env' missing_sizes ps fbody (retType ret) + else pure . TermPoly (Just $ T.BoundV [] ftype) $ \ftype' -> + let resolved = resolveTypeParams (map typeParamName tparams) ftype ftype' + in \eval' -> do + tparam_env <- evalResolved eval' resolved + let env' = tparam_env <> env + -- In some cases (abstract lifted types) there may be + -- missing sizes that were not fixed by the type + -- instantiation. These will have to be set by looking + -- at the actual function arguments. + missing_sizes = + filter (`M.notMember` envTerm env') $ + map typeParamName (filter isSizeParam tparams) + returned env (retType ret) retext + =<< evalFunction env' missing_sizes ps fbody (retType ret) evalArg :: Env -> Exp -> Maybe VName -> EvalM Value evalArg env e ext = do @@ -737,7 +805,9 @@ evalAppExp env _ (Range start maybe_second end loc) = do <> " is invalid." evalAppExp env t (Coerce e te loc) = do v <- eval env e - case checkShape (structTypeShape t) (valueShape v) of + eval' <- evalWithExts env + t' <- evalType eval' mempty $ expandType env t + case checkShape (structTypeShape t') (valueShape v) of Just _ -> pure v Nothing -> bad loc env . docText $ @@ -748,7 +818,7 @@ evalAppExp env t (Coerce e te loc) = do <> "` cannot match shape of type `" <> pretty te <> "` (`" - <> pretty t + <> pretty t' <> "`)" evalAppExp env _ (LetPat sizes p e body _) = do v <- eval env e @@ -775,9 +845,9 @@ evalAppExp then pure $ ValuePrim $ BoolValue True else eval env y | otherwise = do - op' <- eval env $ Var op op_t loc x' <- evalArg env x xext y' <- evalArg env y yext + op' <- eval env $ Var op op_t loc apply2 loc env op' x' y' evalAppExp env _ (If cond e1 e2 _) = do cond' <- asBool <$> eval env cond @@ -902,8 +972,9 @@ eval env (ArrayLit (v : vs) _ _) = do vs' <- mapM (eval env) vs pure $ toArray' (valueShape v') (v' : vs') eval env (AppExp e (Info (AppRes t retext))) = do - t' <- evalType env $ toStruct t - returned env t' retext =<< evalAppExp env t' e + let t' = expandType env $ toStruct t + v <- evalAppExp env t' e + returned env t' retext v eval env (Var qv (Info t) _) = evalTermVar env qv (toStruct t) eval env (Ascript e _ _) = eval env e eval _ (IntLit v (Info t) _) = @@ -1044,9 +1115,10 @@ substituteInModule substs = onModule onTerm (TermPoly t v) = TermPoly t v onTerm (TermModule m) = TermModule $ onModule m onType (T.TypeAbbr l ps t) = T.TypeAbbr l ps $ first onDim t - onDim (NamedSize v) = NamedSize $ replaceQ v - onDim (ConstSize x) = ConstSize x - onDim (AnySize v) = AnySize v + onDim (SizeExpr (Var v typ loc)) = SizeExpr (Var (replaceQ v) typ loc) + onDim (SizeExpr (IntLit x t loc)) = SizeExpr (IntLit x t loc) + onDim (SizeExpr _) = error "Arbitrary expression not supported yet" + onDim AnySize {} = error "substituteInModule onDim: AnySize" evalModuleVar :: Env -> QualName VName -> EvalM Module evalModuleVar env qv = @@ -1095,9 +1167,11 @@ evalModExp env (ModApply f e (Info psubst) (Info rsubst) _) = do _ -> error "Expected ModuleFun." evalDec :: Env -> Dec -> EvalM Env -evalDec env (ValDec (ValBind _ v _ (Info ret) tparams ps fbody _ _ _)) = do +evalDec env (ValDec (ValBind _ v _ (Info ret) tparams ps fbody _ _ _)) = localExts $ do binding <- evalFunctionBinding env tparams ps ret fbody - pure $ env {envTerm = M.insert v binding $ envTerm env} + sizes <- extSizeEnv + pure $ + env {envTerm = M.insert v binding $ envTerm env} <> sizes evalDec env (OpenDec me _) = do me' <- evalModExp env me case me' of @@ -1472,9 +1546,10 @@ initialCtx = def s | "reduce_stream" `isPrefixOf` s = Just $ fun3 $ \_ f arg -> stream f arg def "map" = Just $ - TermPoly Nothing $ \t -> pure $ - ValueFun $ \f -> pure . ValueFun $ \xs -> - case unfoldFunType t of + TermPoly Nothing $ \t eval' -> do + t' <- evalType eval' mempty t + pure $ ValueFun $ \f -> pure . ValueFun $ \xs -> + case unfoldFunType t' of ([_, _], ret_t) | Just rowshape <- typeRowShape ret_t -> toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs) @@ -1600,10 +1675,10 @@ initialCtx = ( ValueArray dest_shape dest_arr, ValueArray _ vs_arr ) -> do - let acc = ValueAcc (\_ x -> pure x) dest_arr + let acc = ValueAcc dest_shape (\_ x -> pure x) dest_arr acc' <- foldM (apply2 noLoc mempty f) acc vs_arr case acc' of - ValueAcc _ dest_arr' -> + ValueAcc _ _ dest_arr' -> pure $ ValueArray dest_shape dest_arr' _ -> error $ "scatter_stream produced: " <> show acc' @@ -1615,10 +1690,10 @@ initialCtx = ( ValueArray dest_shape dest_arr, ValueArray _ vs_arr ) -> do - let acc = ValueAcc (apply2 noLoc mempty op) dest_arr + let acc = ValueAcc dest_shape (apply2 noLoc mempty op) dest_arr acc' <- foldM (apply2 noLoc mempty f) acc vs_arr case acc' of - ValueAcc _ dest_arr' -> + ValueAcc _ _ dest_arr' -> pure $ ValueArray dest_shape dest_arr' _ -> error $ "hist_stream produced: " <> show acc' @@ -1627,14 +1702,14 @@ initialCtx = def "acc_write" = Just $ fun3 $ \acc i v -> case (acc, i) of - ( ValueAcc op acc_arr, + ( ValueAcc shape op acc_arr, ValuePrim (SignedValue (Int64Value i')) ) -> if i' >= 0 && i' < arrayLength acc_arr then do let x = acc_arr ! fromIntegral i' res <- op x v - pure $ ValueAcc op $ acc_arr // [(fromIntegral i', res)] + pure $ ValueAcc shape op $ acc_arr // [(fromIntegral i', res)] else pure acc _ -> error $ "acc_write invalid arguments: " <> prettyString (show acc, show i, show v) @@ -1846,20 +1921,24 @@ initialCtx = interpretExp :: Ctx -> Exp -> F ExtOp Value interpretExp ctx e = runEvalM (ctxImports ctx) $ eval (ctxEnv ctx) e -interpretDec :: Ctx -> Dec -> F ExtOp Ctx -interpretDec ctx d = do - env <- runEvalM (ctxImports ctx) $ do - env <- evalDec (ctxEnv ctx) d +interpretDecs :: Ctx -> [Dec] -> F ExtOp Env +interpretDecs ctx decs = + runEvalM (ctxImports ctx) $ do + env <- foldM evalDec (ctxEnv ctx) decs -- We need to extract any new existential sizes and add them as -- ordinary bindings to the context, or we will not be able to -- look up their values later. sizes <- extSizeEnv pure $ env <> sizes + +interpretDec :: Ctx -> Dec -> F ExtOp Ctx +interpretDec ctx d = do + env <- interpretDecs ctx [d] pure ctx {ctxEnv = env} interpretImport :: Ctx -> (ImportName, Prog) -> F ExtOp Ctx interpretImport ctx (fp, prog) = do - env <- runEvalM (ctxImports ctx) $ foldM evalDec (ctxEnv ctx) $ progDecs prog + env <- interpretDecs ctx $ progDecs prog pure ctx {ctxImports = M.insert fp env $ ctxImports ctx} -- | Produce a context, based on the one passed in, where all of @@ -1922,7 +2001,11 @@ interpretFunction ctx fname vs = do Right $ runEvalM (ctxImports ctx) $ do - f <- evalTermVar (ctxEnv ctx) (qualName fname) ft + -- XXX: We are providing a dummy type here. This is OK as long + -- as the function we invoke is monomorphic, which is what we + -- require of entry points. This is to avoid reimplementing + -- type inference machinery here. + f <- evalTermVar (ctxEnv ctx) (qualName fname) (Scalar (Prim Bool)) foldM (apply noLoc mempty) f vs' where updateType (vt : vts) (Scalar (Arrow als pn d pt (RetType dims rt))) = do diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index 66ffb9a8ef..bb08d6c52d 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -83,13 +83,16 @@ typeShape (Scalar (Record fs)) = ShapeRecord $ M.map typeShape fs typeShape (Scalar (Sum cs)) = ShapeSum $ M.map (map typeShape) cs -typeShape _ = - ShapeLeaf +typeShape t + | Just t' <- isAccType t = + typeShape t' + | otherwise = + ShapeLeaf structTypeShape :: StructType -> Shape (Maybe Int64) structTypeShape = fmap dim . typeShape where - dim (ConstSize d) = Just $ fromIntegral d + dim (SizeExpr (IntLit x _ _)) = Just $ fromIntegral x dim _ = Nothing -- | A fully evaluated Futhark value. @@ -101,8 +104,8 @@ data Value m | ValueFun (Value m -> m (Value m)) | -- Stores the full shape. ValueSum ValueShape Name [Value m] - | -- The update function and the array. - ValueAcc (Value m -> Value m -> m (Value m)) !(Array Int (Value m)) + | -- The shape, the update function, and the array. + ValueAcc ValueShape (Value m -> Value m -> m (Value m)) !(Array Int (Value m)) instance Show (Value m) where show (ValuePrim v) = "ValuePrim " <> show v <> "" @@ -124,7 +127,7 @@ instance Eq (Value m) where ValueArray _ x == ValueArray _ y = x == y ValueRecord x == ValueRecord y = x == y ValueSum _ n1 vs1 == ValueSum _ n2 vs2 = n1 == n2 && vs1 == vs2 - ValueAcc _ x == ValueAcc _ y = x == y + ValueAcc _ _ x == ValueAcc _ _ y = x == y _ == _ = False prettyValueWith :: (PrimValue -> Doc a) -> Value m -> Doc a @@ -174,6 +177,7 @@ valueText = docText . prettyValueWith pretty valueShape :: Value m -> ValueShape valueShape (ValueArray shape _) = shape +valueShape (ValueAcc shape _ _) = shape valueShape (ValueRecord fs) = ShapeRecord $ M.map valueShape fs valueShape (ValueSum shape _ _) = shape valueShape _ = ShapeLeaf diff --git a/src/Language/Futhark/Pretty.hs b/src/Language/Futhark/Pretty.hs index e98ff3b443..d2e194cc7a 100644 --- a/src/Language/Futhark/Pretty.hs +++ b/src/Language/Futhark/Pretty.hs @@ -20,6 +20,7 @@ import Data.Map.Strict qualified as M import Data.Maybe import Data.Monoid hiding (Sum) import Data.Ord +import Data.Text qualified as T import Data.Word import Futhark.Util import Futhark.Util.Pretty @@ -83,8 +84,7 @@ instance Pretty PrimValue where instance Pretty Size where pretty (AnySize Nothing) = mempty pretty (AnySize (Just v)) = "?" <> prettyName v - pretty (NamedSize v) = pretty v - pretty (ConstSize n) = pretty n + pretty (SizeExpr e) = pretty e instance (Eq vn, IsName vn, Annot f) => Pretty (SizeExp f vn) where pretty SizeExpAny {} = brackets mempty @@ -161,8 +161,8 @@ instance Pretty (Shape dim) => Pretty (TypeBase dim as) where pretty = prettyType 0 prettyTypeArg :: Pretty (Shape dim) => Int -> TypeArg dim -> Doc a -prettyTypeArg _ (TypeArgDim d _) = pretty $ Shape [d] -prettyTypeArg p (TypeArgType t _) = prettyType p t +prettyTypeArg _ (TypeArgDim d) = pretty $ Shape [d] +prettyTypeArg p (TypeArgType t) = prettyType p t instance Pretty (TypeArg Size) where pretty = prettyTypeArg 0 @@ -321,8 +321,18 @@ prettyInst t = prettyAttr :: Pretty a => a -> Doc ann prettyAttr attr = "#[" <> pretty attr <> "]" +operatorName :: Name -> Bool +operatorName = (`elem` opchars) . T.head . nameToText + where + opchars :: String + opchars = "+-*/%=!><|&^." + prettyExp :: (Eq vn, IsName vn, Annot f) => Int -> ExpBase f vn -> Doc a -prettyExp _ (Var name t _) = pretty name <> prettyInst t +prettyExp _ (Var name t _) + -- The first case occurs only for programs that have been normalised + -- by the compiler. + | operatorName (toName (qualLeaf name)) = parens $ pretty name <> prettyInst t + | otherwise = pretty name <> prettyInst t prettyExp _ (Hole t _) = "???" <> prettyInst t prettyExp _ (Parens e _) = align $ parens $ pretty e prettyExp _ (QualParens (v, _) e _) = pretty v <> "." <> align (parens $ pretty e) diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 01b77f5efa..854244789f 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -6,6 +6,7 @@ module Language.Futhark.Prop ( -- * Various Intrinsic (..), intrinsics, + intrinsicVar, isBuiltin, isBuiltinLoc, maxIntrinsicTag, @@ -28,6 +29,8 @@ module Language.Futhark.Prop valBindTypeScheme, valBindBound, funType, + stripExp, + similarExps, -- * Queries on patterns and params patIdents, @@ -50,6 +53,7 @@ module Language.Futhark.Prop foldFunType, foldFunTypeFromParams, typeVars, + isAccType, -- * Operations on types peelArray, @@ -94,6 +98,27 @@ module Language.Futhark.Prop UncheckedSpec, UncheckedProg, UncheckedCase, + -- | Type-checked syntactical constructs + Ident, + DimIndex, + Slice, + AppExp, + Exp, + Pat, + ModExp, + ModParam, + SigExp, + ModBind, + SigBind, + ValBind, + Dec, + Spec, + Prog, + TypeBind, + StructTypeArg, + ScalarType, + TypeParam, + Case, ) where @@ -105,6 +130,7 @@ import Data.Bitraversable (bitraverse) import Data.Char import Data.Foldable import Data.List (genericLength, isPrefixOf, sortOn) +import Data.List.NonEmpty qualified as NE import Data.Loc (Loc (..), posFile) import Data.Map.Strict qualified as M import Data.Maybe @@ -185,10 +211,10 @@ traverseDims f = go mempty PosImmediate Named p' -> S.insert p' bound Unnamed -> bound - onTypeArg bound b (TypeArgDim d loc) = - TypeArgDim <$> f bound b d <*> pure loc - onTypeArg bound b (TypeArgType t loc) = - TypeArgType <$> go bound b t <*> pure loc + onTypeArg bound b (TypeArgDim d) = + TypeArgDim <$> f bound b d + onTypeArg bound b (TypeArgType t) = + TypeArgType <$> go bound b t -- | Return the uniqueness of a type. uniqueness :: TypeBase shape as -> Uniqueness @@ -347,8 +373,7 @@ combineTypeShapes (Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1))) (Scalar (Arr combineTypeShapes (Scalar (TypeVar als1 u1 v targs1)) (Scalar (TypeVar als2 _ _ targs2)) = Scalar $ TypeVar (als1 <> als2) u1 v $ zipWith f targs1 targs2 where - f (TypeArgType t1 loc) (TypeArgType t2 _) = - TypeArgType (combineTypeShapes t1 t2) loc + f (TypeArgType t1) (TypeArgType t2) = TypeArgType (combineTypeShapes t1 t2) f targ _ = targ combineTypeShapes (Array als1 u1 shape1 et1) (Array als2 _u2 _shape2 et2) = arrayOfWithAliases @@ -410,8 +435,8 @@ matchDims onDims = matchDims' mempty _ -> pure t1 matchTypeArg _ ta@TypeArgType {} _ = pure ta - matchTypeArg bound (TypeArgDim x loc) (TypeArgDim y _) = - TypeArgDim <$> onDims bound x y <*> pure loc + matchTypeArg bound (TypeArgDim x) (TypeArgDim y) = + TypeArgDim <$> onDims bound x y matchTypeArg _ a _ = pure a onShapes bound shape1 shape2 = @@ -480,11 +505,11 @@ typeOf (RecordLit fs _) = t `addAliases` S.insert (AliasBound name) typeOf (ArrayLit _ (Info t) _) = t -typeOf (StringLit vs _) = +typeOf (StringLit vs loc) = Array mempty Nonunique - (Shape [ConstSize $ genericLength vs]) + (Shape [sizeFromInteger (genericLength vs) loc]) (Prim (Unsigned Int8)) typeOf (Project _ _ (Info t) _) = t typeOf (Var _ (Info t) _) = t @@ -584,7 +609,7 @@ typeVars t = Scalar (Sum cs) -> mconcat $ (foldMap . fmap) typeVars cs Array _ _ _ rt -> typeVars $ Scalar rt where - typeArgFree (TypeArgType ta _) = typeVars ta + typeArgFree (TypeArgType ta) = typeVars ta typeArgFree TypeArgDim {} = mempty -- | @orderZero t@ is 'True' if the argument type has order 0, i.e., it is not @@ -710,13 +735,415 @@ intrinsicAcc = where acc_v = VName "acc" 10 t_v = VName "t" 11 - arg = TypeArgType (Scalar (TypeVar () Nonunique (qualName t_v) [])) mempty + arg = TypeArgType $ Scalar (TypeVar () Nonunique (qualName t_v) []) + +-- | If this type corresponds to the builtin "acc" type, return the +-- type of the underlying array. +isAccType :: TypeBase d as -> Maybe (TypeBase d ()) +isAccType (Scalar (TypeVar _ _ (QualName [] v) [TypeArgType t])) + | v == fst intrinsicAcc = + Just t +isAccType _ = Nothing + +-- | Find the 'VName' corresponding to a builtin. Crashes if that +-- name cannot be found. +intrinsicVar :: Name -> VName +intrinsicVar v = + fromMaybe bad $ find ((v ==) . baseName) $ M.keys intrinsics + where + bad = error $ "findBuiltin: " <> nameToString v -- | A map of all built-ins. intrinsics :: M.Map VName Intrinsic intrinsics = (M.fromList [intrinsicAcc] <>) $ M.fromList $ + primOp + ++ zipWith + namify + [intrinsicStart ..] + ( [ ( "flatten", + IntrinsicPolyFun + [tp_a, sp_n, sp_m] + [(Observe, Array () Nonunique (shape [n, m]) t_a)] + $ RetType [] + $ Array + () + Nonunique + ( Shape + [ SizeExpr + $ AppExp + ( BinOp + (findOp "*", mempty) + sizeBinOpInfo + (Var (qualName n) (Info i64) mempty, Info (i64, Nothing)) + (Var (qualName m) (Info i64) mempty, Info (i64, Nothing)) + mempty + ) + $ Info + $ AppRes i64 [] + ] + ) + t_a + ), + ( "unflatten", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, Scalar $ Prim $ Signed Int64), + (Observe, Scalar $ Prim $ Signed Int64), + (Observe, Array () Nonunique (shape [n]) t_a) + ] + $ RetType [k, m] + $ Array () Nonunique (shape [k, m]) t_a + ), + ( "concat", + IntrinsicPolyFun + [tp_a, sp_n, sp_m] + [ (Observe, array_a $ shape [n]), + (Observe, array_a $ shape [m]) + ] + $ RetType [] + $ uarray_a + $ Shape + [ SizeExpr + $ AppExp + ( BinOp + (findOp "+", mempty) + sizeBinOpInfo + (Var (qualName n) (Info i64) mempty, Info (i64, Nothing)) + (Var (qualName m) (Info i64) mempty, Info (i64, Nothing)) + mempty + ) + $ Info + $ AppRes i64 [] + ] + ), + ( "rotate", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, Scalar $ Prim $ Signed Int64), + (Observe, array_a $ shape [n]) + ] + $ RetType [] + $ array_a + $ shape [n] + ), + ( "transpose", + IntrinsicPolyFun + [tp_a, sp_n, sp_m] + [(Observe, array_a $ shape [n, m])] + $ RetType [] + $ array_a + $ shape [m, n] + ), + ( "scatter", + IntrinsicPolyFun + [tp_a, sp_n, sp_l] + [ (Consume, Array () Unique (shape [n]) t_a), + (Observe, Array () Nonunique (shape [l]) (Prim $ Signed Int64)), + (Observe, Array () Nonunique (shape [l]) t_a) + ] + $ RetType [] + $ Array () Unique (shape [n]) t_a + ), + ( "scatter_2d", + IntrinsicPolyFun + [tp_a, sp_n, sp_m, sp_l] + [ (Consume, uarray_a $ shape [n, m]), + (Observe, Array () Nonunique (shape [l]) (tupInt64 2)), + (Observe, Array () Nonunique (shape [l]) t_a) + ] + $ RetType [] + $ uarray_a + $ shape [n, m] + ), + ( "scatter_3d", + IntrinsicPolyFun + [tp_a, sp_n, sp_m, sp_k, sp_l] + [ (Consume, uarray_a $ shape [n, m, k]), + (Observe, Array () Nonunique (shape [l]) (tupInt64 3)), + (Observe, Array () Nonunique (shape [l]) t_a) + ] + $ RetType [] + $ uarray_a + $ shape [n, m, k] + ), + ( "zip", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [ (Observe, array_a (shape [n])), + (Observe, array_b (shape [n])) + ] + $ RetType [] + $ tuple_uarray (Scalar t_a) (Scalar t_b) + $ shape [n] + ), + ( "unzip", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [(Observe, tuple_arr (Scalar t_a) (Scalar t_b) $ shape [n])] + $ RetType [] . Scalar . Record . M.fromList + $ zip tupleFieldNames [array_a $ shape [n], array_b $ shape [n]] + ), + ( "hist_1d", + IntrinsicPolyFun + [tp_a, sp_n, sp_m] + [ (Consume, Scalar $ Prim $ Signed Int64), + (Observe, uarray_a $ shape [m]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 1)), + (Observe, array_a (shape [n])) + ] + $ RetType [] + $ uarray_a + $ shape [m] + ), + ( "hist_2d", + IntrinsicPolyFun + [tp_a, sp_n, sp_m, sp_k] + [ (Observe, Scalar $ Prim $ Signed Int64), + (Consume, uarray_a $ shape [m, k]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 2)), + (Observe, array_a (shape [n])) + ] + $ RetType [] + $ uarray_a + $ shape [m, k] + ), + ( "hist_3d", + IntrinsicPolyFun + [tp_a, sp_n, sp_m, sp_k, sp_l] + [ (Observe, Scalar $ Prim $ Signed Int64), + (Consume, uarray_a $ shape [m, k, l]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, Array () Nonunique (shape [n]) (tupInt64 3)), + (Observe, array_a (shape [n])) + ] + $ RetType [] + $ uarray_a + $ shape [m, k, l] + ), + ( "map", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, array_a $ shape [n]) + ] + $ RetType [] + $ uarray_b + $ shape [n] + ), + ( "reduce", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) + ] + $ RetType [] + $ Scalar t_a + ), + ( "reduce_comm", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) + ] + $ RetType [] + $ Scalar t_a + ), + ( "scan", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + (Observe, array_a $ shape [n]) + ] + $ RetType [] + $ uarray_a + $ shape [n] + ), + ( "partition", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, Scalar (Prim $ Signed Int32)), + (Observe, Scalar t_a `arr` Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [n]) + ] + ( RetType [m] . Scalar $ + tupleRecord + [ uarray_a $ shape [k], + Array () Unique (shape [n]) (Prim $ Signed Int64) + ] + ) + ), + ( "acc_write", + IntrinsicPolyFun + [sp_k, tp_a] + [ (Consume, Scalar $ accType array_ka), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar t_a) + ] + $ RetType [] + $ Scalar + $ accType array_ka + ), + ( "scatter_stream", + IntrinsicPolyFun + [tp_a, tp_b, sp_k, sp_n] + [ (Consume, uarray_ka), + ( Observe, + Scalar (accType array_ka) + `carr` ( Scalar t_b + `arr` Scalar (accType $ array_a $ shape [k]) + ) + ), + (Observe, array_b $ shape [n]) + ] + $ RetType [] uarray_ka + ), + ( "hist_stream", + IntrinsicPolyFun + [tp_a, tp_b, sp_k, sp_n] + [ (Consume, uarray_a $ shape [k]), + (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), + (Observe, Scalar t_a), + ( Observe, + Scalar (accType array_ka) + `carr` ( Scalar t_b + `arr` Scalar (accType $ array_a $ shape [k]) + ) + ), + (Observe, array_b $ shape [n]) + ] + $ RetType [] + $ uarray_a + $ shape [k] + ), + ( "jvp2", + IntrinsicPolyFun + [tp_a, tp_b] + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, Scalar t_a), + (Observe, Scalar t_a) + ] + $ RetType [] + $ Scalar + $ tupleRecord [Scalar t_b, Scalar t_b] + ), + ( "vjp2", + IntrinsicPolyFun + [tp_a, tp_b] + [ (Observe, Scalar t_a `arr` Scalar t_b), + (Observe, Scalar t_a), + (Observe, Scalar t_b) + ] + $ RetType [] + $ Scalar + $ tupleRecord [Scalar t_b, Scalar t_a] + ) + ] + ++ + -- Experimental LMAD ones. + [ ( "flat_index_2d", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) + ] + $ RetType [m, k] + $ array_a + $ shape [m, k] + ), + ( "flat_update_2d", + IntrinsicPolyFun + [tp_a, sp_n, sp_k, sp_l] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l]) + ] + $ RetType [] + $ uarray_a + $ shape [n] + ), + ( "flat_index_3d", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) + ] + $ RetType [m, k, l] + $ array_a + $ shape [m, k, l] + ), + ( "flat_update_3d", + IntrinsicPolyFun + [tp_a, sp_n, sp_k, sp_l, sp_p] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l, p]) + ] + $ RetType [] + $ uarray_a + $ shape [n] + ), + ( "flat_index_4d", + IntrinsicPolyFun + [tp_a, sp_n] + [ (Observe, array_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)) + ] + $ RetType [m, k, l, p] + $ array_a + $ shape [m, k, l, p] + ), + ( "flat_update_4d", + IntrinsicPolyFun + [tp_a, sp_n, sp_k, sp_l, sp_p, sp_q] + [ (Consume, uarray_a $ shape [n]), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, Scalar (Prim $ Signed Int64)), + (Observe, array_a $ shape [k, l, p, q]) + ] + $ RetType [] + $ uarray_a + $ shape [n] + ) + ] + ) + where + primOp = zipWith namify [20 ..] $ map primFun (M.toList Primitive.primFuns) ++ map unOpFun Primitive.allUnOps @@ -748,357 +1175,14 @@ intrinsics = -- The reason for the loop formulation is to ensure that we -- get a missing case warning if we forget a case. mapMaybe mkIntrinsicBinOp [minBound .. maxBound] - ++ [ ( "flatten", - IntrinsicPolyFun - [tp_a, sp_n, sp_m] - [(Observe, Array () Nonunique (shape [n, m]) t_a)] - $ RetType [k] - $ Array () Nonunique (shape [k]) t_a - ), - ( "unflatten", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, Scalar $ Prim $ Signed Int64), - (Observe, Scalar $ Prim $ Signed Int64), - (Observe, Array () Nonunique (shape [n]) t_a) - ] - $ RetType [k, m] - $ Array () Nonunique (shape [k, m]) t_a - ), - ( "concat", - IntrinsicPolyFun - [tp_a, sp_n, sp_m] - [ (Observe, array_a $ shape [n]), - (Observe, array_a $ shape [m]) - ] - $ RetType [k] - $ uarray_a - $ shape [k] - ), - ( "rotate", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, Scalar $ Prim $ Signed Int64), - (Observe, array_a $ shape [n]) - ] - $ RetType [] - $ array_a - $ shape [n] - ), - ( "transpose", - IntrinsicPolyFun - [tp_a, sp_n, sp_m] - [(Observe, array_a $ shape [n, m])] - $ RetType [] - $ array_a - $ shape [m, n] - ), - ( "scatter", - IntrinsicPolyFun - [tp_a, sp_n, sp_l] - [ (Consume, Array () Unique (shape [n]) t_a), - (Observe, Array () Nonunique (shape [l]) (Prim $ Signed Int64)), - (Observe, Array () Nonunique (shape [l]) t_a) - ] - $ RetType [] - $ Array () Unique (shape [n]) t_a - ), - ( "scatter_2d", - IntrinsicPolyFun - [tp_a, sp_n, sp_m, sp_l] - [ (Consume, uarray_a $ shape [n, m]), - (Observe, Array () Nonunique (shape [l]) (tupInt64 2)), - (Observe, Array () Nonunique (shape [l]) t_a) - ] - $ RetType [] - $ uarray_a - $ shape [n, m] - ), - ( "scatter_3d", - IntrinsicPolyFun - [tp_a, sp_n, sp_m, sp_k, sp_l] - [ (Consume, uarray_a $ shape [n, m, k]), - (Observe, Array () Nonunique (shape [l]) (tupInt64 3)), - (Observe, Array () Nonunique (shape [l]) t_a) - ] - $ RetType [] - $ uarray_a - $ shape [n, m, k] - ), - ( "zip", - IntrinsicPolyFun - [tp_a, tp_b, sp_n] - [ (Observe, array_a (shape [n])), - (Observe, array_b (shape [n])) - ] - $ RetType [] - $ tuple_uarray (Scalar t_a) (Scalar t_b) - $ shape [n] - ), - ( "unzip", - IntrinsicPolyFun - [tp_a, tp_b, sp_n] - [(Observe, tuple_arr (Scalar t_a) (Scalar t_b) $ shape [n])] - $ RetType [] . Scalar . Record . M.fromList - $ zip tupleFieldNames [array_a $ shape [n], array_b $ shape [n]] - ), - ( "hist_1d", - IntrinsicPolyFun - [tp_a, sp_n, sp_m] - [ (Consume, Scalar $ Prim $ Signed Int64), - (Observe, uarray_a $ shape [m]), - (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), - (Observe, Scalar t_a), - (Observe, Array () Nonunique (shape [n]) (tupInt64 1)), - (Observe, array_a (shape [n])) - ] - $ RetType [] - $ uarray_a - $ shape [m] - ), - ( "hist_2d", - IntrinsicPolyFun - [tp_a, sp_n, sp_m, sp_k] - [ (Observe, Scalar $ Prim $ Signed Int64), - (Consume, uarray_a $ shape [m, k]), - (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), - (Observe, Scalar t_a), - (Observe, Array () Nonunique (shape [n]) (tupInt64 2)), - (Observe, array_a (shape [n])) - ] - $ RetType [] - $ uarray_a - $ shape [m, k] - ), - ( "hist_3d", - IntrinsicPolyFun - [tp_a, sp_n, sp_m, sp_k, sp_l] - [ (Observe, Scalar $ Prim $ Signed Int64), - (Consume, uarray_a $ shape [m, k, l]), - (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), - (Observe, Scalar t_a), - (Observe, Array () Nonunique (shape [n]) (tupInt64 3)), - (Observe, array_a (shape [n])) - ] - $ RetType [] - $ uarray_a - $ shape [m, k, l] - ), - ( "map", - IntrinsicPolyFun - [tp_a, tp_b, sp_n] - [ (Observe, Scalar t_a `arr` Scalar t_b), - (Observe, array_a $ shape [n]) - ] - $ RetType [] - $ uarray_b - $ shape [n] - ), - ( "reduce", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), - (Observe, Scalar t_a), - (Observe, array_a $ shape [n]) - ] - $ RetType [] - $ Scalar t_a - ), - ( "reduce_comm", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), - (Observe, Scalar t_a), - (Observe, array_a $ shape [n]) - ] - $ RetType [] - $ Scalar t_a - ), - ( "scan", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), - (Observe, Scalar t_a), - (Observe, array_a $ shape [n]) - ] - $ RetType [] - $ uarray_a - $ shape [n] - ), - ( "partition", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, Scalar (Prim $ Signed Int32)), - (Observe, Scalar t_a `arr` Scalar (Prim $ Signed Int64)), - (Observe, array_a $ shape [n]) - ] - ( RetType [m] . Scalar $ - tupleRecord - [ uarray_a $ shape [k], - Array () Unique (shape [n]) (Prim $ Signed Int64) - ] - ) - ), - ( "acc_write", - IntrinsicPolyFun - [sp_k, tp_a] - [ (Consume, Scalar $ accType array_ka), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar t_a) - ] - $ RetType [] - $ Scalar - $ accType array_ka - ), - ( "scatter_stream", - IntrinsicPolyFun - [tp_a, tp_b, sp_k, sp_n] - [ (Consume, uarray_ka), - ( Observe, - Scalar (accType array_ka) - `carr` ( Scalar t_b - `arr` Scalar (accType $ array_a $ shape [k]) - ) - ), - (Observe, array_b $ shape [n]) - ] - $ RetType [] uarray_ka - ), - ( "hist_stream", - IntrinsicPolyFun - [tp_a, tp_b, sp_k, sp_n] - [ (Consume, uarray_a $ shape [k]), - (Observe, Scalar t_a `arr` (Scalar t_a `arr` Scalar t_a)), - (Observe, Scalar t_a), - ( Observe, - Scalar (accType array_ka) - `carr` ( Scalar t_b - `arr` Scalar (accType $ array_a $ shape [k]) - ) - ), - (Observe, array_b $ shape [n]) - ] - $ RetType [] - $ uarray_a - $ shape [k] - ), - ( "jvp2", - IntrinsicPolyFun - [tp_a, tp_b] - [ (Observe, Scalar t_a `arr` Scalar t_b), - (Observe, Scalar t_a), - (Observe, Scalar t_a) - ] - $ RetType [] - $ Scalar - $ tupleRecord [Scalar t_b, Scalar t_b] - ), - ( "vjp2", - IntrinsicPolyFun - [tp_a, tp_b] - [ (Observe, Scalar t_a `arr` Scalar t_b), - (Observe, Scalar t_a), - (Observe, Scalar t_b) - ] - $ RetType [] - $ Scalar - $ tupleRecord [Scalar t_b, Scalar t_a] - ) - ] - ++ - -- Experimental LMAD ones. - [ ( "flat_index_2d", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, array_a $ shape [n]), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)) - ] - $ RetType [m, k] - $ array_a - $ shape [m, k] - ), - ( "flat_update_2d", - IntrinsicPolyFun - [tp_a, sp_n, sp_k, sp_l] - [ (Consume, uarray_a $ shape [n]), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, array_a $ shape [k, l]) - ] - $ RetType [] - $ uarray_a - $ shape [n] - ), - ( "flat_index_3d", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, array_a $ shape [n]), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)) - ] - $ RetType [m, k, l] - $ array_a - $ shape [m, k, l] - ), - ( "flat_update_3d", - IntrinsicPolyFun - [tp_a, sp_n, sp_k, sp_l, sp_p] - [ (Consume, uarray_a $ shape [n]), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, array_a $ shape [k, l, p]) - ] - $ RetType [] - $ uarray_a - $ shape [n] - ), - ( "flat_index_4d", - IntrinsicPolyFun - [tp_a, sp_n] - [ (Observe, array_a $ shape [n]), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)) - ] - $ RetType [m, k, l, p] - $ array_a - $ shape [m, k, l, p] - ), - ( "flat_update_4d", - IntrinsicPolyFun - [tp_a, sp_n, sp_k, sp_l, sp_p, sp_q] - [ (Consume, uarray_a $ shape [n]), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, Scalar (Prim $ Signed Int64)), - (Observe, array_a $ shape [k, l, p, q]) - ] - $ RetType [] - $ uarray_a - $ shape [n] - ) - ] - where + + intrinsicStart = 1 + baseTag (fst $ last primOp) + + findOp op = + qualName $ maybe bad fst $ find ((op ==) . baseString . fst) primOp + where + bad = error $ "Intrinsics making, findOp: \"" <> op <> "\"" + [a, b, n, m, k, l, p, q] = zipWith VName (map nameFromString ["a", "b", "n", "m", "k", "l", "p", "q"]) [0 ..] t_a = TypeVar () Nonunique (qualName a) [] @@ -1113,7 +1197,14 @@ intrinsics = [sp_n, sp_m, sp_k, sp_l, sp_p, sp_q] = map (`TypeParamDim` mempty) [n, m, k, l, p, q] - shape = Shape . map (NamedSize . qualName) + shape = + Shape + . map + (flip sizeFromName mempty . qualName) + + i64 = Scalar $ Prim $ Signed Int64 + + sizeBinOpInfo = Info $ foldFunType [(Observe, i64), (Observe, i64)] $ RetType [] i64 tuple_arr x y s = Array @@ -1126,11 +1217,11 @@ intrinsics = arr x y = Scalar $ Arrow mempty Unnamed Observe x (RetType [] y) carr x y = Scalar $ Arrow mempty Unnamed Consume x (RetType [] y) - array_ka = Array () Nonunique (Shape [NamedSize $ qualName k]) t_a - uarray_ka = Array () Unique (Shape [NamedSize $ qualName k]) t_a + array_ka = Array () Nonunique (Shape [sizeFromName (qualName k) mempty]) t_a + uarray_ka = Array () Unique (Shape [sizeFromName (qualName k) mempty]) t_a accType t = - TypeVar () Unique (qualName (fst intrinsicAcc)) [TypeArgType t mempty] + TypeVar () Unique (qualName (fst intrinsicAcc)) [TypeArgType t] namify i (x, y) = (VName (nameFromString x) i, y) @@ -1371,6 +1462,154 @@ progHoles = foldMap holesInDec . progDecs pure e onExp e = astMap (identityMapper {mapOnExp = onExp}) e +-- | Strip semantically irrelevant stuff from the top level of +-- expression. This is used to provide a slightly fuzzy notion of +-- expression equality. +-- +-- Ideally we'd implement unification on a simpler representation that +-- simply didn't allow us. +stripExp :: Exp -> Maybe Exp +stripExp (Parens e _) = stripExp e `mplus` Just e +stripExp (Assert _ e _ _) = stripExp e `mplus` Just e +stripExp (Attr _ e _) = stripExp e `mplus` Just e +stripExp (Ascript e _ _) = stripExp e `mplus` Just e +stripExp _ = Nothing + +similarSlices :: Slice -> Slice -> Maybe [(Exp, Exp)] +similarSlices slice1 slice2 + | length slice1 == length slice2 = do + concat <$> zipWithM match slice1 slice2 + | otherwise = Nothing + where + match (DimFix e1) (DimFix e2) = Just [(e1, e2)] + match (DimSlice a1 b1 c1) (DimSlice a2 b2 c2) = + concat <$> sequence [pair (a1, a2), pair (b1, b2), pair (c1, c2)] + match _ _ = Nothing + pair (Nothing, Nothing) = Just [] + pair (Just x, Just y) = Just [(x, y)] + pair _ = Nothing + +-- | If these two expressions are structurally similar at top level as +-- sizes, produce their subexpressions (which are not necessarily +-- similar, but you can check for that!). This is the machinery +-- underlying expresssion unification. +similarExps :: Exp -> Exp -> Maybe [(Exp, Exp)] +similarExps e1 e2 | e1 == e2 = Just [] +similarExps e1 e2 | Just e1' <- stripExp e1 = similarExps e1' e2 +similarExps e1 e2 | Just e2' <- stripExp e2 = similarExps e1 e2' +similarExps + (AppExp (BinOp (op1, _) _ (x1, _) (y1, _) _) _) + (AppExp (BinOp (op2, _) _ (x2, _) (y2, _) _) _) + | op1 == op2 = Just [(x1, x2), (y1, y2)] +similarExps (AppExp (Apply f1 args1 _) _) (AppExp (Apply f2 args2 _) _) + | f1 == f2 = Just $ zip (map snd $ NE.toList args1) (map snd $ NE.toList args2) +similarExps (AppExp (Index arr1 slice1 _) _) (AppExp (Index arr2 slice2 _) _) + | arr1 == arr2, + length slice1 == length slice2 = + similarSlices slice1 slice2 +similarExps (TupLit es1 _) (TupLit es2 _) + | length es1 == length es2 = + Just $ zip es1 es2 +similarExps (RecordLit fs1 _) (RecordLit fs2 _) + | length fs1 == length fs2 = + zipWithM onFields fs1 fs2 + where + onFields (RecordFieldExplicit n1 fe1 _) (RecordFieldExplicit n2 fe2 _) + | n1 == n2 = Just (fe1, fe2) + onFields (RecordFieldImplicit vn1 ty1 _) (RecordFieldImplicit vn2 ty2 _) = + Just (Var (qualName vn1) ty1 mempty, Var (qualName vn2) ty2 mempty) + onFields _ _ = Nothing +similarExps (ArrayLit es1 _ _) (ArrayLit es2 _ _) + | length es1 == length es2 = + Just $ zip es1 es2 +similarExps (Project field1 e1 _ _) (Project field2 e2 _ _) + | field1 == field2 = + Just [(e1, e2)] +similarExps (Negate e1 _) (Negate e2 _) = + Just [(e1, e2)] +similarExps (Not e1 _) (Not e2 _) = + Just [(e1, e2)] +similarExps (Constr n1 es1 _ _) (Constr n2 es2 _ _) + | length es1 == length es2, + n1 == n2 = + Just $ zip es1 es2 +similarExps (Update e1 slice1 e'1 _) (Update e2 slice2 e'2 _) = + ([(e1, e2), (e'1, e'2)] ++) <$> similarSlices slice1 slice2 +similarExps (RecordUpdate e1 names1 e'1 _ _) (RecordUpdate e2 names2 e'2 _ _) + | names1 == names2 = + Just [(e1, e2), (e'1, e'2)] +similarExps (OpSection op1 _ _) (OpSection op2 _ _) + | op1 == op2 = Just [] +similarExps (OpSectionLeft op1 _ x1 _ _ _) (OpSectionLeft op2 _ x2 _ _ _) + | op1 == op2 = Just [(x1, x2)] +similarExps (OpSectionRight op1 _ x1 _ _ _) (OpSectionRight op2 _ x2 _ _ _) + | op1 == op2 = Just [(x1, x2)] +similarExps (ProjectSection names1 _ _) (ProjectSection names2 _ _) + | names1 == names2 = Just [] +similarExps (IndexSection slice1 _ _) (IndexSection slice2 _ _) = + similarSlices slice1 slice2 +similarExps _ _ = Nothing + +-- | An identifier with type- and aliasing information. +type Ident = IdentBase Info VName + +-- | An index with type information. +type DimIndex = DimIndexBase Info VName + +-- | A slice with type information. +type Slice = SliceBase Info VName + +-- | An expression with type information. +type Exp = ExpBase Info VName + +-- | An application expression with type information. +type AppExp = AppExpBase Info VName + +-- | A pattern with type information. +type Pat = PatBase Info VName + +-- | An constant declaration with type information. +type ValBind = ValBindBase Info VName + +-- | A type binding with type information. +type TypeBind = TypeBindBase Info VName + +-- | A type-checked module binding. +type ModBind = ModBindBase Info VName + +-- | A type-checked module type binding. +type SigBind = SigBindBase Info VName + +-- | A type-checked module expression. +type ModExp = ModExpBase Info VName + +-- | A type-checked module parameter. +type ModParam = ModParamBase Info VName + +-- | A type-checked module type expression. +type SigExp = SigExpBase Info VName + +-- | A type-checked declaration. +type Dec = DecBase Info VName + +-- | A type-checked specification. +type Spec = SpecBase Info VName + +-- | An Futhark program with type information. +type Prog = ProgBase Info VName + +-- | A known type arg with shape annotations. +type StructTypeArg = TypeArg Size + +-- | A type-checked type parameter. +type TypeParam = TypeParamBase VName + +-- | A known scalar type with no shape annotations. +type ScalarType = ScalarTypeBase () + +-- | A type-checked case (of a match expression). +type Case = CaseBase Info VName + -- | A type with no aliasing information but shape annotations. type UncheckedType = TypeBase (Shape Name) () diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index 09b22f6b6b..5b08d27cff 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -93,6 +93,10 @@ module Language.Futhark.Syntax QualName (..), mkApply, mkApplyUT, + sizeVar, + sizeInteger, + sizeFromName, + sizeFromInteger, ) where @@ -219,18 +223,36 @@ data AttrInfo vn -- | The elaborated size of a dimension. data Size - = -- | The size of the dimension is this name, which - -- must be in scope. In a return type, this will - -- give rise to an assertion. - NamedSize (QualName VName) - | -- | The size is a constant. - ConstSize Int64 + = -- | The size of the dimension is this expression + -- all non-trivial expression should have variable in scope. + -- In a return type, existential name don't appear in expression. + SizeExpr (ExpBase Info VName) | -- | No known size. If @Nothing@, then this is a name distinct -- from any other. The type checker should _never_ produce these -- - they are a (hopefully temporary) thing introduced by -- defunctorisation and monomorphisation. AnySize (Maybe VName) - deriving (Eq, Ord, Show) + deriving (Show, Eq, Ord) + +instance Located Size where + locOf (SizeExpr e) = locOf e + locOf AnySize {} = mempty + +-- | Create a 'Var' expression of type @i64@. +sizeVar :: QualName VName -> SrcLoc -> ExpBase Info VName +sizeVar name = Var name (Info $ Scalar $ Prim $ Signed Int64) + +-- | Create an 'IntLit' expression of type @i64@. +sizeInteger :: Integer -> SrcLoc -> ExpBase Info VName +sizeInteger x = IntLit x (Info <$> Scalar $ Prim $ Signed Int64) + +-- | Create a 'Size' with 'sizeVar'. +sizeFromName :: QualName VName -> SrcLoc -> Size +sizeFromName name loc = SizeExpr $ sizeVar name loc + +-- | Create a 'Size' with 'sizeInt'. +sizeFromInteger :: Integer -> SrcLoc -> Size +sizeFromInteger x loc = SizeExpr $ sizeInteger x loc -- | The size of an array type is a list of its dimension sizes. If -- 'Nothing', that dimension is of a (statically) unknown size. @@ -352,13 +374,13 @@ instance Bifoldable TypeBase where -- | An argument passed to a type constructor. data TypeArg dim - = TypeArgDim dim SrcLoc - | TypeArgType (TypeBase dim ()) SrcLoc + = TypeArgDim dim + | TypeArgType (TypeBase dim ()) deriving (Eq, Ord, Show) instance Traversable TypeArg where - traverse f (TypeArgDim v loc) = TypeArgDim <$> f v <*> pure loc - traverse f (TypeArgType t loc) = TypeArgType <$> bitraverse f pure t <*> pure loc + traverse f (TypeArgDim v) = TypeArgDim <$> f v + traverse f (TypeArgType t) = TypeArgType <$> bitraverse f pure t instance Functor TypeArg where fmap = fmapDefault @@ -654,7 +676,7 @@ data AppExpBase f vn -- constructing 'Apply' directly. -- -- The @Maybe VNames@ are existential sizes generated by this - -- argumnet. May have duplicates across the program, but they + -- argument. May have duplicates across the program, but they -- will all produce the same value (the expressions will be -- identical). Apply diff --git a/src/Language/Futhark/Traversals.hs b/src/Language/Futhark/Traversals.hs index 8706441ee8..1a59e2ac73 100644 --- a/src/Language/Futhark/Traversals.hs +++ b/src/Language/Futhark/Traversals.hs @@ -26,6 +26,7 @@ module Language.Futhark.Traversals ) where +import Data.Bifunctor import Data.List.NonEmpty qualified as NE import Data.Set qualified as S import Language.Futhark.Syntax @@ -270,8 +271,7 @@ instance ASTMappable (SizeExp Info VName) where astMap _ (SizeExpAny loc) = pure $ SizeExpAny loc instance ASTMappable Size where - astMap tv (NamedSize vn) = NamedSize <$> astMap tv vn - astMap _ (ConstSize k) = pure $ ConstSize k + astMap tv (SizeExpr expr) = SizeExpr <$> mapOnExp tv expr astMap tv (AnySize vn) = AnySize <$> traverse (mapOnName tv) vn instance ASTMappable (TypeParamBase VName) where @@ -332,10 +332,10 @@ traverseTypeArg :: (dim1 -> f dim2) -> TypeArg dim1 -> f (TypeArg dim2) -traverseTypeArg _ g (TypeArgDim d loc) = - TypeArgDim <$> g d <*> pure loc -traverseTypeArg f g (TypeArgType t loc) = - TypeArgType <$> traverseType f g pure t <*> pure loc +traverseTypeArg _ g (TypeArgDim d) = + TypeArgDim <$> g d +traverseTypeArg f g (TypeArgType t) = + TypeArgType <$> traverseType f g pure t instance ASTMappable StructType where astMap tv = traverseType (astMap tv) (astMap tv) pure @@ -443,8 +443,26 @@ bareLoopForm (While e) = While (bareExp e) bareCase :: CaseBase Info VName -> CaseBase NoInfo VName bareCase (CasePat pat e loc) = CasePat (barePat pat) (bareExp e) loc +bareSizeExp :: SizeExp Info VName -> SizeExp NoInfo VName +bareSizeExp (SizeExp e loc) = SizeExp (bareExp e) loc +bareSizeExp (SizeExpAny loc) = SizeExpAny loc + bareTypeExp :: TypeExp Info VName -> TypeExp NoInfo VName -bareTypeExp = undefined +bareTypeExp (TEVar qn loc) = TEVar qn loc +bareTypeExp (TEParens te loc) = TEParens (bareTypeExp te) loc +bareTypeExp (TETuple tys loc) = TETuple (map bareTypeExp tys) loc +bareTypeExp (TERecord fs loc) = TERecord (map (second bareTypeExp) fs) loc +bareTypeExp (TEArray size ty loc) = TEArray (bareSizeExp size) (bareTypeExp ty) loc +bareTypeExp (TEUnique ty loc) = TEUnique (bareTypeExp ty) loc +bareTypeExp (TEApply ty ta loc) = TEApply (bareTypeExp ty) (bareTypeArgExp ta) loc + where + bareTypeArgExp (TypeArgExpSize size) = + TypeArgExpSize $ bareSizeExp size + bareTypeArgExp (TypeArgExpType tya) = + TypeArgExpType $ bareTypeExp tya +bareTypeExp (TEArrow arg tya tyr loc) = TEArrow arg (bareTypeExp tya) (bareTypeExp tyr) loc +bareTypeExp (TESum cs loc) = TESum (map (second $ map bareTypeExp) cs) loc +bareTypeExp (TEDim names ty loc) = TEDim names (bareTypeExp ty) loc -- | Remove all annotations from an expression, but retain the -- name/scope information. diff --git a/src/Language/Futhark/TypeChecker.hs b/src/Language/Futhark/TypeChecker.hs index 4bfab35bac..101dab3bbc 100644 --- a/src/Language/Futhark/TypeChecker.hs +++ b/src/Language/Futhark/TypeChecker.hs @@ -51,7 +51,7 @@ checkProg :: UncheckedProg -> (Warnings, Either TypeError (FileModule, VNameSource)) checkProg files src name prog = - runTypeM initialEnv files' name src $ checkProgM prog + runTypeM initialEnv files' name src checkSizeExp $ checkProgM prog where files' = M.map fileEnv $ M.fromList files @@ -67,7 +67,7 @@ checkExp :: UncheckedExp -> (Warnings, Either TypeError ([TypeParam], Exp)) checkExp files src env e = - second (fmap fst) $ runTypeM env files' (mkInitialImport "") src $ checkOneExp e + second (fmap fst) $ runTypeM env files' (mkInitialImport "") src checkSizeExp $ checkOneExp e where files' = M.map fileEnv $ M.fromList files @@ -84,7 +84,7 @@ checkDec :: (Warnings, Either TypeError (Env, Dec, VNameSource)) checkDec files src env name d = second (fmap massage) $ - runTypeM env files' name src $ do + runTypeM env files' name src checkSizeExp $ do (_, env', d') <- checkOneDec d pure (env' <> env, d') where @@ -103,7 +103,7 @@ checkModExp :: ModExpBase NoInfo Name -> (Warnings, Either TypeError (MTy, ModExpBase Info VName)) checkModExp files src env me = - second (fmap fst) . runTypeM env files' (mkInitialImport "") src $ do + second (fmap fst) . runTypeM env files' (mkInitialImport "") src checkSizeExp $ do (_abs, mty, me') <- checkOneModExp me pure (mty, me') where @@ -564,7 +564,7 @@ checkTypeBind (TypeBind name l tps te NoInfo doc loc) = let elab_t = RetType (svars ++ dims) t - let used_dims = freeInType t + let used_dims = fvVars $ freeInType t case filter ((`S.notMember` used_dims) . typeParamName) $ filter isSizeParam tps' of [] -> pure () @@ -611,7 +611,7 @@ entryPoint params orig_ret_te (RetType ret orig_ret) = -- Since the entry point type is not a RetType but just a plain -- StructType, we have to remove any existentially bound sizes. - extToAny (NamedSize v) | qualLeaf v `elem` ret = AnySize Nothing + extToAny (SizeExpr (Var v _ _)) | qualLeaf v `elem` ret = AnySize Nothing extToAny d = d patternEntry (PatParens p _) = @@ -657,13 +657,19 @@ checkEntryPoint loc tparams params maybe_tdecl rettype "Entry point functions may not be higher-order." | sizes_only_in_ret <- S.fromList (map typeParamName tparams) - `S.intersection` freeInType rettype' - `S.difference` foldMap freeInType param_ts, + `S.intersection` fvVars (freeInType rettype') + `S.difference` foldMap (fvVars . freeInType) param_ts, not $ S.null sizes_only_in_ret = typeError loc mempty $ withIndexLink "size-polymorphic-entry" "Entry point functions must not be size-polymorphic in their return type." + | (constructive, _) <- foldMap determineSizeWitnesses param_ts, + Just p <- L.find (flip S.notMember constructive . typeParamName) tparams = + typeError p mempty . withIndexLink "nonconstructive-entry" $ + "Entry point size parameter " + <> pretty p + <> " only used non-constructively." | p : _ <- filter nastyParameter params = warn p $ "Entry point parameter\n" diff --git a/src/Language/Futhark/TypeChecker/Modules.hs b/src/Language/Futhark/TypeChecker/Modules.hs index 60d7202eb9..46973b7ed0 100644 --- a/src/Language/Futhark/TypeChecker/Modules.hs +++ b/src/Language/Futhark/TypeChecker/Modules.hs @@ -10,7 +10,6 @@ where import Control.Monad import Data.Either -import Data.List (intersect) import Data.Map.Strict qualified as M import Data.Maybe import Data.Ord @@ -106,6 +105,10 @@ newNamesForMTy orig_mty = do substitute v = fromMaybe v $ M.lookup v substs + -- For applySubst and friends. + subst v = + ExpSubst . flip sizeVar mempty . qualName <$> M.lookup v substs + substituteInMap f m = let (ks, vs) = unzip $ M.toList m in M.fromList $ @@ -151,20 +154,16 @@ newNamesForMTy orig_mty = do substituteInType (Scalar (Arrow als v d1 t1 (RetType dims t2))) = Scalar $ Arrow als v d1 (substituteInType t1) $ RetType dims $ substituteInType t2 - substituteInShape (Shape ds) = - Shape $ map substituteInDim ds - substituteInDim (NamedSize (QualName qs v)) = - NamedSize $ QualName (map substitute qs) $ substitute v - substituteInDim d = d - - substituteInTypeArg (TypeArgDim (NamedSize (QualName qs v)) loc) = - TypeArgDim (NamedSize $ QualName (map substitute qs) $ substitute v) loc - substituteInTypeArg (TypeArgDim (ConstSize x) loc) = - TypeArgDim (ConstSize x) loc - substituteInTypeArg (TypeArgDim (AnySize v) loc) = - TypeArgDim (AnySize v) loc - substituteInTypeArg (TypeArgType t loc) = - TypeArgType (substituteInType t) loc + substituteInShape (Shape ds) = Shape $ map substituteInDim ds + substituteInDim (SizeExpr e) = SizeExpr $ applySubst subst e + substituteInDim AnySize {} = error "substituteInDim: AnySize" + + substituteInTypeArg (TypeArgDim (SizeExpr e)) = + TypeArgDim $ SizeExpr (applySubst subst e) + substituteInTypeArg (TypeArgDim AnySize {}) = + error "substituteInTypeArg: AnySize" + substituteInTypeArg (TypeArgType t) = + TypeArgType $ substituteInType t mtyTypeAbbrs :: MTy -> M.Map VName TypeBinding mtyTypeAbbrs (MTy _ mod) = modTypeAbbrs mod @@ -389,7 +388,7 @@ matchMTys :: Either TypeError (M.Map VName VName) matchMTys orig_mty orig_mty_sig = matchMTys' - (M.map (SizeSubst . NamedSize) $ resolveMTyNames orig_mty orig_mty_sig) + (M.map (ExpSubst . flip sizeVar mempty) $ resolveMTyNames orig_mty orig_mty_sig) [] orig_mty orig_mty_sig @@ -534,15 +533,17 @@ matchMTys orig_mty orig_mty_sig = -- if we have a value of an abstract type 't [n]', then there is -- an array of size 'n' somewhere inside. when (M.member spec_name abs_subst_to_type) $ - case S.toList (mustBeExplicitInType (retType t)) `intersect` map typeParamName ps of + case filter + (`S.notMember` fst (determineSizeWitnesses (retType t))) + (map typeParamName $ filter isSizeParam ps) of [] -> pure () d : _ -> Left . TypeError loc mempty $ "Type" indent 2 (ppTypeAbbr [] (QualName quals name) (l, ps, t)) textwrap "cannot be made abstract because size parameter" - dquotes (prettyName d) - textwrap "is not used as an array size in the definition." + indent 2 (prettyName d) + textwrap "is not used constructively as an array size in the definition." let spec_t' = applySubst (`M.lookup` abs_subst_to_type) spec_t nonrigid = ps <> map (`TypeParamDim` mempty) (retDims t) @@ -615,7 +616,7 @@ applyFunctor applyloc (FunSig p_abs p_mod body_mty) a_mty = do let a_abbrs = mtyTypeAbbrs a_mty isSub v = case M.lookup v a_abbrs of Just abbr -> Just $ substFromAbbr abbr - _ -> Just $ SizeSubst $ NamedSize $ qualName v + _ -> Just $ ExpSubst $ sizeVar (qualName v) mempty type_subst = M.mapMaybe isSub p_subst body_mty' = substituteTypesInMTy (`M.lookup` type_subst) body_mty (body_mty'', body_subst) <- newNamesForMTy body_mty' diff --git a/src/Language/Futhark/TypeChecker/Monad.hs b/src/Language/Futhark/TypeChecker/Monad.hs index c61ef0527a..e1c8a98037 100644 --- a/src/Language/Futhark/TypeChecker/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Monad.hs @@ -162,7 +162,8 @@ data Context = Context contextImportName :: ImportName, -- | Currently type-checking at the top level? If false, we are -- inside a module. - contextAtTopLevel :: Bool + contextAtTopLevel :: Bool, + contextCheckExp :: UncheckedExp -> TypeM Exp } data TypeState = TypeState @@ -205,10 +206,11 @@ runTypeM :: ImportTable -> ImportName -> VNameSource -> + (UncheckedExp -> TypeM Exp) -> TypeM a -> (Warnings, Either TypeError (a, VNameSource)) -runTypeM env imports fpath src (TypeM m) = do - let ctx = Context env imports fpath True +runTypeM env imports fpath src checker (TypeM m) = do + let ctx = Context env imports fpath True checker s = TypeState src mempty 0 case runExcept $ runStateT (runReaderT m ctx) s of Left (ws, e) -> (ws, Left e) @@ -286,17 +288,7 @@ class Monad m => MonadTypeChecker m where lookupMod :: SrcLoc -> QualName Name -> m (QualName VName, Mod) lookupVar :: SrcLoc -> QualName Name -> m (QualName VName, PatType) - checkNamedSize :: SrcLoc -> QualName Name -> m (QualName VName) - checkNamedSize loc v = do - (v', t) <- lookupVar loc v - case t of - Scalar (Prim (Signed Int64)) -> pure v' - _ -> - typeError loc mempty $ - "Sizes must have type i64, but" - <+> dquotes (pretty v) - <+> "has type:" - pretty t + checkExpForSize :: UncheckedExp -> m Exp typeError :: Located loc => loc -> Notes -> Doc () -> m a @@ -379,6 +371,10 @@ instance MonadTypeChecker TypeM where qualifyTypeVars outer_env mempty qs t' ) + checkExpForSize e = do + checker <- asks contextCheckExp + checker e + typeError loc notes s = throwError $ TypeError (locOf loc) notes s -- | Extract from a type a first-order type. @@ -440,12 +436,12 @@ qualifyTypeVars outer_env orig_except ref_qs = onType (S.fromList orig_except) Named p' -> S.insert p' except Unnamed -> except - onTypeArg except (TypeArgDim d loc) = - TypeArgDim (onDim except d) loc - onTypeArg except (TypeArgType t loc) = - TypeArgType (onType except t) loc + onTypeArg except (TypeArgDim d) = + TypeArgDim $ onDim except d + onTypeArg except (TypeArgType t) = + TypeArgType $ onType except t - onDim except (NamedSize qn) = NamedSize $ qual except qn + onDim except (SizeExpr (Var qn typ loc)) = SizeExpr $ Var (qual except qn) typ loc onDim _ d = d qual except (QualName orig_qs name) diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 7cf4d130a8..5586f73e75 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -8,6 +8,7 @@ -- will require the programmer to fall back on type annotations. module Language.Futhark.TypeChecker.Terms ( checkOneExp, + checkSizeExp, checkFunDef, ) where @@ -16,7 +17,9 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State +import Data.Bifunctor import Data.Bitraversable +import Data.Char (isAscii) import Data.Either import Data.List (find, foldl', genericLength, partition) import Data.List.NonEmpty qualified as NE @@ -37,6 +40,61 @@ import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) +hasBinding :: Exp -> Bool +hasBinding Literal {} = False +hasBinding IntLit {} = False +hasBinding FloatLit {} = False +hasBinding StringLit {} = False +hasBinding Hole {} = False +hasBinding Var {} = False +hasBinding (Parens e _) = hasBinding e +hasBinding (QualParens _ e _) = hasBinding e +hasBinding (TupLit es _) = any hasBinding es +hasBinding (RecordLit fs _) = any f fs + where + f (RecordFieldExplicit _ e _) = hasBinding e + f RecordFieldImplicit {} = False +hasBinding (ArrayLit es _ _) = any hasBinding es +hasBinding (Attr _ e _) = hasBinding e +hasBinding (Project _ e _ _) = hasBinding e +hasBinding (Negate e _) = hasBinding e +hasBinding (Not e _) = hasBinding e +hasBinding (Assert _ e _ _) = hasBinding e +hasBinding (Constr _ es _ _) = any hasBinding es +hasBinding (Update e1 slice e2 _) = hasBinding e1 || hasBinding e2 || any f slice + where + f (DimFix e) = hasBinding e + f (DimSlice me1 me2 me3) = any (maybe False hasBinding) [me1, me2, me3] +hasBinding (RecordUpdate e1 _ e2 _ _) = hasBinding e1 || hasBinding e2 +hasBinding Lambda {} = True +hasBinding OpSection {} = False +hasBinding (OpSectionLeft _ _ e _ _ _) = hasBinding e +hasBinding (OpSectionRight _ _ e _ _ _) = hasBinding e +hasBinding ProjectSection {} = False +hasBinding (IndexSection slice _ _) = any f slice + where + f (DimFix e) = hasBinding e + f (DimSlice me1 me2 me3) = any (maybe False hasBinding) [me1, me2, me3] +hasBinding (Ascript e _ _) = hasBinding e +hasBinding (AppExp (Apply f es _) _) = hasBinding f || any (hasBinding . snd) es +hasBinding (AppExp (Coerce e _ _) _) = hasBinding e +hasBinding (AppExp (Range ei es ef _) _) = hasBinding ei || maybe False hasBinding es || f ef + where + f (DownToExclusive e) = hasBinding e + f (ToInclusive e) = hasBinding e + f (UpToExclusive e) = hasBinding e +hasBinding (AppExp LetPat {} _) = True +hasBinding (AppExp LetFun {} _) = True +hasBinding (AppExp (If ec et ef _) _) = hasBinding ec || hasBinding et || hasBinding ef +hasBinding (AppExp DoLoop {} _) = True +hasBinding (AppExp (BinOp _ _ (el, _) (er, _) _) _) = hasBinding el || hasBinding er +hasBinding (AppExp LetWith {} _) = True +hasBinding (AppExp (Index e slice _) _) = hasBinding e || any f slice + where + f (DimFix e') = hasBinding e' + f (DimSlice me1 me2 me3) = any (maybe False hasBinding) [me1, me2, me3] +hasBinding (AppExp Match {} _) = True + overloadedTypeVars :: Constraints -> Names overloadedTypeVars = mconcat . map f . M.elems where @@ -62,7 +120,7 @@ unifyBranches loc e1 e2 = do sliceShape :: Maybe (SrcLoc, Rigidity) -> - Slice -> + [(DimIndex, Maybe Occurrence)] -> TypeBase Size as -> TermTypeM (TypeBase Size as, [VName]) sliceShape r slice t@(Array als u (Shape orig_dims) et) = @@ -88,34 +146,82 @@ sliceShape r slice t@(Array als u (Shape orig_dims) et) = modify (maybeToList ext ++) pure d Just (loc, Nonrigid) -> - lift $ NamedSize . qualName <$> newDimVar loc Nonrigid "slice_dim" + lift $ + flip sizeFromName loc . qualName + <$> newFlexibleDim (mkUsage loc "size of slice") "slice_dim" Nothing -> do v <- lift $ newID "slice_anydim" modify (v :) - pure $ NamedSize $ qualName v + pure $ sizeFromName (qualName v) mempty where -- The original size does not matter if the slice is fully specified. orig_d' | isJust i, isJust j = Nothing | otherwise = Just orig_d - adjustDims (DimFix {} : idxes') (_ : dims) = + warnIfConsumingOrBinding mOcc binds d i j stride size = + case (mOcc, binds) of + (Just occ, _) -> do + lift . warn (location occ) $ + withIndexLink + "size-expression-consume" + "Size expression with consumption is replaced by unknown size." + (:) <$> sliceSize d i j stride + (_, True) -> do + lift . warn (srclocOf size) $ + withIndexLink + "size-expression-bind" + "Size expression with binding is replaced by unknown size." + (:) <$> sliceSize d i j stride + (_, False) -> + pure (size :) + + adjustDims ((DimFix {}, _) : idxes') (_ : dims) = adjustDims idxes' dims -- Pat match some known slices to be non-existential. - adjustDims (DimSlice i j stride : idxes') (_ : dims) + adjustDims ((DimSlice i j stride, mOcc) : idxes') (d : dims) | refine_sizes, maybe True ((== Just 0) . isInt64) i, - Just j' <- maybeDimFromExp =<< j, maybe True ((== Just 1) . isInt64) stride = - (j' :) <$> adjustDims idxes' dims - adjustDims (DimSlice Nothing Nothing stride : idxes') (d : dims) + let j' = maybe d SizeExpr j + in warnIfConsumingOrBinding mOcc (maybe False hasBinding j) d i j stride j' <*> adjustDims idxes' dims + adjustDims ((DimSlice i j stride, mOcc) : idxes') (d : dims) | refine_sizes, - maybe True (maybe False ((== 1) . abs) . isInt64) stride = + Just i' <- i, -- if i ~ 0, previous case + maybe True ((== Just 1) . isInt64) stride = + let j' = fromMaybe (unSizeExpr d) j + in warnIfConsumingOrBinding mOcc (hasBinding j' || hasBinding i') d i j stride (sizeMinus j' i') <*> adjustDims idxes' dims + -- stride == -1 + adjustDims ((DimSlice Nothing Nothing stride, _) : idxes') (d : dims) + | refine_sizes, + maybe True ((== Just (-1)) . isInt64) stride = (d :) <$> adjustDims idxes' dims - adjustDims (DimSlice i j stride : idxes') (d : dims) = + adjustDims ((DimSlice (Just i) (Just j) stride, mOcc) : idxes') (d : dims) + | refine_sizes, + maybe True ((== Just (-1)) . isInt64) stride = + warnIfConsumingOrBinding mOcc (hasBinding i || hasBinding j) d (Just i) (Just j) stride (sizeMinus i j) <*> adjustDims idxes' dims + -- existential + adjustDims ((DimSlice i j stride, _) : idxes') (d : dims) = (:) <$> sliceSize d i j stride <*> adjustDims idxes' dims adjustDims _ dims = pure dims + + sizeMinus j i = + SizeExpr + $ AppExp + ( BinOp + (qualName (intrinsicVar "-"), mempty) + sizeBinOpInfo + (j, Info (i64, Nothing)) + (i, Info (i64, Nothing)) + mempty + ) + $ Info + $ AppRes i64 [] + i64 = Scalar $ Prim $ Signed Int64 + sizeBinOpInfo = Info $ foldFunType [(Observe, i64), (Observe, i64)] $ RetType [] i64 + unSizeExpr (SizeExpr e) = e + unSizeExpr AnySize {} = undefined sliceShape _ _ t = pure (t, []) --- Main checkers @@ -185,41 +291,129 @@ checkCoerce loc te e = do where makeNonExtFresh ext = bitraverse onDim pure where - onDim d@(NamedSize v) + onDim d@(SizeExpr (Var v _ _)) | qualLeaf v `elem` ext = pure d - onDim _ = do + onDim d = do v <- newTypeName "coerce" constrain v . Size Nothing $ mkUsage loc "a size coercion where the underlying expression size cannot be determined" - pure $ NamedSize $ qualName v + pure $ sizeFromName (qualName v) (srclocOf d) -unscopeType :: +sizeFree :: SrcLoc -> - M.Map VName Ident -> - PatType -> - TermTypeM (PatType, [VName]) -unscopeType tloc unscoped t = do - (t', m) <- runStateT (traverseDims onDim t) mempty - pure (t' `addAliases` S.map unAlias, M.elems m) + (Exp -> Maybe VName) -> + TypeBase Size as -> + TermTypeM (TypeBase Size as, [VName]) +sizeFree tloc expKiller orig_t = do + scope <- asks termScope + (orig_t', m) <- runStateT (onType (M.keysSet $ scopeVtable scope) orig_t) mempty + pure (orig_t', M.elems m) where - onDim bound _ (NamedSize d) - | Just loc <- srclocOf <$> M.lookup (qualLeaf d) unscoped, - not $ qualLeaf d `S.member` bound = - inst loc $ qualLeaf d - onDim _ _ d = pure d - - inst loc d = do - prev <- gets $ M.lookup d + -- using StateT (M.Map Exp VName) TermTypeM a + onScalar scope (Record fs) = + Record <$> traverse (onType scope) fs + onScalar scope (Sum cs) = + Sum <$> (traverse . traverse) (onType scope) cs + onScalar scope (Arrow as argName d argT (RetType dims retT)) = do + argT' <- onType scope argT + retT' <- onType (scope `S.union` argset) retT + rl <- + state . M.partitionWithKey $ + const . not . S.disjoint intros . fvVars . freeInExp . unSizeExpr + let dims' = dims <> M.elems rl + pure $ Arrow as argName d argT' (RetType dims' retT') + where + -- to check : completeness of the filter + intros = argset `S.difference` scope + argset = + fvVars (freeInType argT) + <> case argName of + Unnamed -> mempty + Named vn -> S.singleton vn + onScalar scope (TypeVar as u v args) = + TypeVar as u v <$> mapM onTypeArg args + where + onTypeArg (TypeArgDim d) = TypeArgDim <$> onSize d + onTypeArg (TypeArgType ty) = TypeArgType <$> onType scope ty + onScalar _ (Prim pt) = pure $ Prim pt + + unSizeExpr (SizeExpr e) = e + unSizeExpr AnySize {} = error "unSizeExpr: AnySize" + + onType :: + S.Set VName -> + TypeBase Size as -> + StateT (M.Map Size VName) TermTypeM (TypeBase Size as) -- Precise the typing, else haskell refuse it + onType scope (Array as u shape scalar) = + Array as u <$> traverse onSize shape <*> onScalar scope scalar + onType scope (Scalar ty) = + Scalar <$> onScalar scope ty + + onSize (SizeExpr e) = onExp e + onSize AnySize {} = error "onSize: AnySize" + + onExp e = do + let e' = SizeExpr e + prev <- gets $ M.lookup e' case prev of - Just d' -> pure $ NamedSize $ qualName d' + Just vn -> pure $ sizeFromName (qualName vn) (srclocOf e) Nothing -> do - d' <- lift $ newDimVar tloc (Rigid $ RigidOutOfScope loc d) "d" - modify $ M.insert d d' - pure $ NamedSize $ qualName d' + case expKiller e of + Nothing -> pure $ SizeExpr e + Just cause -> do + vn <- lift $ newRigidDim tloc (RigidOutOfScope (srclocOf e) cause) "d" + modify $ M.insert (SizeExpr e) vn + pure $ sizeFromName (qualName vn) (srclocOf e) + +-- Used to remove unknowable sizes from function body types before we +-- perform let-generalisation. This is because if a function is +-- inferred to return something of type '[x+y]t' where 'x' or 'y' are +-- unknowable, we want to turn that into '[z]t', where ''z' is a fresh +-- unknowable, which is then by let-generalisation turned into +-- '?[z].[z]t'. +unscopeUnknowable :: + TypeBase Size as -> + TermTypeM (TypeBase Size as) +unscopeUnknowable t = do + constraints <- getConstraints + -- These sizes will be immediately turned into existentials, so we + -- do not need to care about their location. + fst <$> sizeFree mempty (expKiller constraints) t + where + expKiller _ Var {} = Nothing + expKiller constraints e = + S.lookupMin $ S.filter (isUnknown constraints) $ fvVars $ freeInExp e + isUnknown constraints vn + | Just UnknowableSize {} <- snd <$> M.lookup vn constraints = True + isUnknown _ _ = False + +unscopeTypeBase :: + SrcLoc -> + S.Set VName -> + TypeBase Size as -> + TermTypeM (TypeBase Size as, [VName]) +unscopeTypeBase tloc unscoped = + sizeFree tloc $ S.lookupMin . S.intersection unscoped . fvVars . freeInExp - unAlias (AliasBound v) | v `M.member` unscoped = AliasFree v +unscopeStructType :: + SrcLoc -> + S.Set VName -> + StructType -> + TermTypeM (StructType, [VName]) +unscopeStructType = unscopeTypeBase + +unscopePatType :: + SrcLoc -> + S.Set VName -> + PatType -> + TermTypeM (PatType, [VName]) +unscopePatType tloc unscoped t = do + (t', m) <- unscopeTypeBase tloc unscoped t + pure (t' `addAliases` S.map unAlias, m) + where + unAlias (AliasBound v) | v `S.member` unscoped = AliasFree v unAlias a = a checkExp :: UncheckedExp -> TermTypeM Exp @@ -274,17 +468,17 @@ checkExp (ArrayLit all_es _ loc) = case all_es of [] -> do et <- newTypeVar loc "t" - t <- arrayOfM loc et (Shape [ConstSize 0]) Nonunique + t <- arrayOfM loc et (Shape [sizeFromInteger 0 mempty]) Nonunique pure $ ArrayLit [] (Info t) loc e : es -> do e' <- checkExp e et <- expType e' es' <- mapM (unifies "type of first array element" (toStruct et) <=< checkExp) es et' <- normTypeFully et - t <- arrayOfM loc et' (Shape [ConstSize $ genericLength all_es]) Nonunique + t <- arrayOfM loc et' (Shape [sizeFromInteger (genericLength all_es) mempty]) Nonunique pure $ ArrayLit (e' : es') (Info t) loc checkExp (AppExp (Range start maybe_step end loc) _) = do - start' <- require "use in range expression" anySignedType =<< checkExp start + (start', startOcc) <- tapOccurrences $ require "use in range expression" anySignedType =<< checkExp start start_t <- toStruct <$> expTypeFully start' maybe_step' <- case maybe_step of Nothing -> pure Nothing @@ -297,7 +491,7 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do Just <$> (unifies "use in range expression" start_t =<< checkExp step) let unifyRange e = unifies "use in range expression" start_t =<< checkExp e - end' <- traverse unifyRange end + (end', endOcc) <- tapOccurrences $ traverse unifyRange end end_t <- case end' of DownToExclusive e -> expType e @@ -305,26 +499,61 @@ checkExp (AppExp (Range start maybe_step end loc) _) = do UpToExclusive e -> expType e -- Special case some ranges to give them a known size. - let dimFromBound = dimFromExp (SourceBound . bareExp) + let warnIfConsumingOrBinding binds size = + case (anyConsumption (startOcc <> endOcc), binds) of + (Just occ, _) -> do + warn (location occ) $ + withIndexLink + "size-expression-consume" + "Size expression with consumption is replaced by unknown size." + d <- newRigidDim loc RigidRange "range_dim" + pure (sizeFromName (qualName d) mempty, Just d) + (_, True) -> do + warn (srclocOf size) $ + withIndexLink + "size-expression-bind" + "Size expression with binding is replaced by unknown size." + d <- newRigidDim loc RigidRange "range_dim" + pure (sizeFromName (qualName d) mempty, Just d) + (_, False) -> + pure (size, Nothing) (dim, retext) <- case (isInt64 start', isInt64 <$> maybe_step', end') of (Just 0, Just (Just 1), UpToExclusive end'') | Scalar (Prim (Signed Int64)) <- end_t -> - dimFromBound end'' + warnIfConsumingOrBinding (hasBinding end'') $ SizeExpr end'' (Just 0, Nothing, UpToExclusive end'') | Scalar (Prim (Signed Int64)) <- end_t -> - dimFromBound end'' + warnIfConsumingOrBinding (hasBinding end'') $ SizeExpr end'' + (_, Nothing, UpToExclusive end'') + | Scalar (Prim (Signed Int64)) <- end_t -> + warnIfConsumingOrBinding (hasBinding end'' || hasBinding start') $ sizeMinus end'' start' (Just 1, Just (Just 2), ToInclusive end'') | Scalar (Prim (Signed Int64)) <- end_t -> - dimFromBound end'' + warnIfConsumingOrBinding (hasBinding end'') $ SizeExpr end'' _ -> do - d <- newDimVar loc (Rigid RigidRange) "range_dim" - pure (NamedSize $ qualName d, Just d) + d <- newRigidDim loc RigidRange "range_dim" + pure (sizeFromName (qualName d) mempty, Just d) t <- arrayOfM loc start_t (Shape [dim]) Nonunique let res = AppRes (t `setAliases` mempty) (maybeToList retext) pure $ AppExp (Range start' maybe_step' end' loc) (Info res) + where + sizeMinus j i = + SizeExpr + $ AppExp + ( BinOp + (qualName (intrinsicVar "-"), mempty) + sizeBinOpInfo + (j, Info (i64, Nothing)) + (i, Info (i64, Nothing)) + mempty + ) + $ Info + $ AppRes i64 [] + i64 = Scalar $ Prim $ Signed Int64 + sizeBinOpInfo = Info $ foldFunType [(Observe, i64), (Observe, i64)] $ RetType [] i64 checkExp (Ascript e te loc) = do (te', e') <- checkAscript loc te e pure $ Ascript e' te' loc @@ -377,7 +606,7 @@ checkExp (AppExp (If e1 e2 e3 loc) _) = let bool = Scalar $ Prim Bool e1_t <- toStruct <$> expType e1' onFailure (CheckingRequired [bool] e1_t) $ - unify (mkUsage (srclocOf e1') "use as 'if' condition") bool e1_t + unify (mkUsage e1' "use as 'if' condition") bool e1_t pure e1' checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc @@ -439,6 +668,7 @@ checkExp (AppExp (Apply fe args loc) NoInfo) = do Var v _ _ -> Just v _ -> Nothing ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' + pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts where onArg fname (i, all_exts, t) arg' = do @@ -460,18 +690,20 @@ checkExp (AppExp (LetPat sizes pat e body loc) _) = (toStruct t) _ -> pure () + constraints <- getConstraints incLevel . bindingSizes sizes $ \sizes' -> bindingPat sizes' pat (Ascribed t) $ \pat' -> do + let implicitSizes = S.filter (isUnknown constraints) $ fvVars $ freeInPat pat' body' <- checkExp body (body_t, retext) <- - unscopeType loc (sizesMap sizes' <> patternMap pat') =<< expTypeFully body' + unscopePatType loc (sizesMap sizes' <> patNames pat' <> implicitSizes) =<< expTypeFully body' pure $ AppExp (LetPat sizes' pat' e' body' loc) (Info $ AppRes body_t retext) where - sizesMap = foldMap onSize - onSize size = - M.singleton (sizeName size) $ - Ident (sizeName size) (Info (Scalar $ Prim $ Signed Int64)) (srclocOf size) + sizesMap = foldMap (S.singleton . sizeName) + isUnknown constraints vn + | Just UnknowableSize {} <- snd <$> M.lookup vn constraints = True + isUnknown _ _ = False checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = sequentially (checkBinding (name, maybe_retdecl, tparams, params, e, loc)) $ \(tparams', params', maybe_retdecl', rettype, e') closure -> do @@ -492,11 +724,8 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l } body' <- localScope bindF $ checkExp body - -- We fake an ident here, but it's OK as it can't be a size - -- anyway. - let fake_ident = Ident name' (Info $ fromStruct ftype) mempty (body_t, ext) <- - unscopeType loc (M.singleton name' fake_ident) + unscopePatType loc (S.singleton name') =<< expTypeFully body' pure $ @@ -510,7 +739,7 @@ checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body l (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = sequentially ((,) <$> checkIdent src <*> checkSlice slice) $ \(src', slice') _ -> do - (t, _) <- newArrayType (srclocOf src) "src" $ sliceDims slice' + (t, _) <- newArrayType (mkUsage src "type of source array") "src" $ sliceDims slice' unify (mkUsage loc "type of target array") t $ toStruct $ unInfo $ identType src' -- Need the fully normalised type here to get the proper aliasing information. @@ -526,12 +755,12 @@ checkExp (AppExp (LetWith dest src slice ve body loc) _) = bindingIdent dest (src_t `setAliases` S.empty) $ \dest' -> do body' <- consuming src' $ checkExp body (body_t, ext) <- - unscopeType loc (M.singleton (identName dest') dest') + unscopePatType loc (S.singleton (identName dest')) =<< expTypeFully body' - pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t ext) + pure $ AppExp (LetWith dest' src' (map fst slice') ve' body' loc) (Info $ AppRes body_t ext) checkExp (Update src slice ve loc) = do slice' <- checkSlice slice - (t, _) <- newArrayType (srclocOf src) "src" $ sliceDims slice' + (t, _) <- newArrayType (mkUsage' src) "src" $ sliceDims slice' (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t sequentially (checkExp ve >>= unifies "type of target array" elemt) $ \ve' _ -> @@ -543,7 +772,7 @@ checkExp (Update src slice ve loc) = do unless (S.null $ src_als `S.intersection` aliases ve_t) $ badLetWithValue src ve loc consume loc src_als - pure $ Update src' slice' ve' loc + pure $ Update src' (map fst slice') ve' loc -- Record updates are a bit hacky, because we do not have row typing -- (yet?). For now, we only permit record updates where we know the @@ -559,7 +788,7 @@ checkExp (RecordUpdate src fields ve NoInfo loc) = do where usage = mkUsage loc "record update" updateField [] ve_t src_t = do - (src_t', _) <- allDimsFreshInType loc Nonrigid "any" src_t + (src_t', _) <- allDimsFreshInType usage Nonrigid "any" src_t onFailure (CheckingRecordUpdate fields (toStruct src_t') (toStruct ve_t)) $ unify usage (toStruct src_t') (toStruct ve_t) -- Important that we return ve_t so that we get the right aliases. @@ -577,7 +806,7 @@ checkExp (RecordUpdate src fields ve NoInfo loc) = do -- checkExp (AppExp (Index e slice loc) _) = do slice' <- checkSlice slice - (t, _) <- newArrayType loc "e" $ sliceDims slice' + (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' e' <- unifies "being indexed at" t =<< checkExp e -- XXX, the RigidSlice here will be overridden in sliceShape with a proper value. (t', retext) <- @@ -588,13 +817,13 @@ checkExp (AppExp (Index e slice loc) _) = do -- will certainly not be aliased. t'' <- noAliasesIfOverloaded t' - pure $ AppExp (Index e' slice' loc) (Info $ AppRes t'' retext) + pure $ AppExp (Index e' (map fst slice') loc) (Info $ AppRes t'' retext) checkExp (Assert e1 e2 NoInfo loc) = do e1' <- require "being asserted" [Bool] =<< checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc checkExp (Lambda params body rettype_te NoInfo loc) = do - (params', body', body_t, rettype', info) <- + (params', body', body_t, rettype', Info (as, RetType dims ty)) <- removeSeminullOccurrences . noUnique . incLevel . bindingParams [] params $ \_ params' -> do rettype_checked <- traverse checkTypeExpNonrigid rettype_te let declared_rettype = @@ -626,7 +855,9 @@ checkExp (Lambda params body rettype_te NoInfo loc) = do checkGlobalAliases params' body_t loc verifyFunctionParams Nothing params' - pure $ Lambda params' body' rettype' info loc + (ty', dims') <- unscopeStructType loc (S.fromList dims) ty + + pure $ Lambda params' body' rettype' (Info (as, RetType dims' ty')) loc where -- Inferring the sizes of the return type of a lambda is a lot -- like let-generalisation. We wish to remove any rigid sizes @@ -650,7 +881,7 @@ checkExp (Lambda params body rettype_te NoInfo loc) = do | name `S.member` hidden_sizes = S.singleton name onDim _ = mempty - pure $ RetType (S.toList $ foldMap onDim $ freeInType ret) ret + pure $ RetType (S.toList $ foldMap onDim $ fvVars $ freeInType ret) ret checkExp (OpSection op _ loc) = do (op', ftype) <- lookupVar loc op pure $ OpSection op' (Info ftype) loc @@ -676,20 +907,23 @@ checkExp (OpSectionRight op _ e _ NoInfo loc) = do e_arg <- checkArg e case ftype of Scalar (Arrow as1 m1 d1 t1 (RetType [] (Scalar (Arrow as2 m2 d2 t2 (RetType dims2 ret))))) -> do - (_, t2', ret', argext, _) <- + (_, t2', arrow', argext, _) <- checkApply loc (Just op', 1) - (Scalar $ Arrow as2 m2 d2 t2 $ RetType [] $ Scalar $ Arrow as1 m1 d1 t1 $ RetType [] ret) + (Scalar $ Arrow as2 m2 d2 t2 $ RetType [] $ Scalar $ Arrow as1 m1 d1 t1 $ RetType dims2 ret) e_arg - pure $ - OpSectionRight - op' - (Info ftype) - (argExp e_arg) - (Info (m1, toStruct t1), Info (m2, toStruct t2', argext)) - (Info $ RetType dims2 $ addAliases ret (<> aliases ret')) - loc + case arrow' of + Scalar (Arrow _ _ _ t1' (RetType dims2' ret')) -> + pure $ + OpSectionRight + op' + (Info ftype) + (argExp e_arg) + (Info (m1, toStruct t1'), Info (m2, toStruct t2', argext)) + (Info $ RetType dims2' $ addAliases ret' (<> aliases arrow')) + loc + _ -> undefined _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype @@ -701,10 +935,10 @@ checkExp (ProjectSection fields NoInfo loc) = do pure $ ProjectSection fields (Info ft) loc checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice - (t, _) <- newArrayType loc "e" $ sliceDims slice' + (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' (t', retext) <- sliceShape Nothing slice' t let ft = Scalar $ Arrow mempty Unnamed Observe t $ RetType retext $ fromStruct t' - pure $ IndexSection slice' (Info ft) loc + pure $ IndexSection (map fst slice') (Info ft) loc checkExp (AppExp (DoLoop _ mergepat mergeexp form loopbody loc) _) = do ((sparams, mergepat', mergeexp', form', loopbody'), appres) <- checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc @@ -757,9 +991,15 @@ checkCase :: TermTypeM (CaseBase Info VName, PatType, [VName]) checkCase mt (CasePat p e loc) = bindingPat [] p (Ascribed mt) $ \p' -> do + constraints <- getConstraints + let implicitSizes = S.filter (isUnknown constraints) $ fvVars $ freeInPat p' e' <- checkExp e - (t, retext) <- unscopeType loc (patternMap p') =<< expTypeFully e' + (t, retext) <- unscopePatType loc (patNames p' <> implicitSizes) =<< expTypeFully e' pure (CasePat p' e' loc, t, retext) + where + isUnknown constraints vn + | Just UnknowableSize {} <- snd <$> M.lookup vn constraints = True + isUnknown _ _ = False -- | An unmatched pattern. Used in in the generation of -- unmatched pattern warnings by the type checker. @@ -794,13 +1034,15 @@ checkIdent (Ident name _ loc) = do (QualName _ name', vt) <- lookupVar loc (qualName name) pure $ Ident name' (Info vt) loc -checkSlice :: UncheckedSlice -> TermTypeM Slice +checkSlice :: UncheckedSlice -> TermTypeM [(DimIndex, Maybe Occurrence)] checkSlice = mapM checkDimIndex where - checkDimIndex (DimFix i) = - DimFix <$> (require "use as index" anySignedType =<< checkExp i) - checkDimIndex (DimSlice i j s) = - DimSlice <$> check i <*> check j <*> check s + checkDimIndex (DimFix i) = do + (i', dflow) <- tapOccurrences (require "use as index" anySignedType =<< checkExp i) + pure (DimFix i', anyConsumption dflow) + checkDimIndex (DimSlice i j s) = do + (sl, dflow) <- tapOccurrences $ DimSlice <$> check i <*> check j <*> check s + pure (sl, anyConsumption dflow) check = maybe (pure Nothing) $ @@ -808,7 +1050,7 @@ checkSlice = mapM checkDimIndex -- The number of dimensions affected by this slice (so the minimum -- rank of the array we are slicing). -sliceDims :: Slice -> Int +sliceDims :: [(DimIndex, Maybe Occurrence)] -> Int sliceDims = length type Arg = (Exp, PatType, Occurrences, SrcLoc) @@ -830,8 +1072,18 @@ instantiateDimsInReturnType :: Maybe (QualName VName) -> RetTypeBase Size als -> TermTypeM (TypeBase Size als, [VName]) -instantiateDimsInReturnType tloc fname = - instantiateEmptyArrayDims tloc $ Rigid $ RigidRet fname +instantiateDimsInReturnType loc fname (RetType dims t) = do + dims' <- mapM new dims + pure (first (onDim $ zip dims dims') t, dims') + where + new = + newRigidDim loc (RigidRet fname) + . nameFromString + . takeWhile isAscii + . baseString + onDim dims' (SizeExpr (Var d _ _)) = + sizeFromName (maybe d qualName (lookup (qualLeaf d) dims')) loc + onDim _ d = d -- Some information about the function/operator we are trying to -- apply, and how many arguments it has previously accepted. Used for @@ -844,7 +1096,7 @@ boundInsideType (Array _ _ _ t) = boundInsideType (Scalar t) boundInsideType (Scalar Prim {}) = mempty boundInsideType (Scalar (TypeVar _ _ _ targs)) = foldMap f targs where - f (TypeArgType t _) = boundInsideType t + f (TypeArgType t) = boundInsideType t f TypeArgDim {} = mempty boundInsideType (Scalar (Record fs)) = foldMap boundInsideType fs boundInsideType (Scalar (Sum cs)) = foldMap (foldMap boundInsideType) cs @@ -860,9 +1112,9 @@ boundInsideType (Scalar (Arrow _ pn _ t1 (RetType dims t2))) = dimUses :: StructType -> (Names, Names) dimUses = flip execState mempty . traverseDims f where - f bound _ (NamedSize v) | qualLeaf v `S.member` bound = pure () - f _ PosImmediate (NamedSize v) = modify ((S.singleton (qualLeaf v), mempty) <>) - f _ PosParam (NamedSize v) = modify ((mempty, S.singleton (qualLeaf v)) <>) + f bound _ (SizeExpr (Var v _ _)) | qualLeaf v `S.member` bound = pure () + f _ PosImmediate (SizeExpr (Var v _ _)) = modify ((S.singleton (qualLeaf v), mempty) <>) + f _ PosParam (SizeExpr (Var v _ _)) = modify ((mempty, S.singleton (qualLeaf v)) <>) f _ _ _ = pure () checkApply :: @@ -877,7 +1129,7 @@ checkApply (Scalar (Arrow as pname d1 tp1 tp2)) (argexp, argtype, dflow, argloc) = onFailure (CheckingApply fname argexp tp1 (toStruct argtype)) $ do - expect (mkUsage argloc "use as function argument") (toStruct tp1) (toStruct argtype) + unify (mkUsage argloc "use as function argument") (toStruct tp1) (toStruct argtype) -- Perform substitutions of instantiated variables in the types. tp1' <- normTypeFully tp1 @@ -915,12 +1167,33 @@ checkApply (argext, parsubst) <- case pname of Named pname' - | (Scalar (Prim (Signed Int64))) <- tp1' -> do - (d, argext) <- sizeFromArg fname argexp - pure - ( argext, - (`M.lookup` M.singleton pname' (SizeSubst d)) - ) + | M.member pname' (unFV $ freeInType tp2') -> + case (isJust (anyConsumption dflow), hasBinding argexp) of + (True, _) -> do + warn (srclocOf argexp) $ + withIndexLink + "size-expression-consume" + "Size expression with consumption is replaced by unknown size." + d <- newRigidDim argexp (RigidArg fname $ prettyTextOneLine $ bareExp argexp) "n" + pure + ( Just d, + (`M.lookup` M.singleton pname' (ExpSubst $ sizeVar (qualName d) $ srclocOf argexp)) + ) + (_, True) -> do + warn (srclocOf argexp) $ + withIndexLink + "size-expression-bind" + "Size expression with binding is replaced by unknown size." + d <- newRigidDim argexp (RigidArg fname $ prettyTextOneLine $ bareExp argexp) "n" + pure + ( Just d, + (`M.lookup` M.singleton pname' (ExpSubst $ sizeVar (qualName d) $ srclocOf argexp)) + ) + (False, False) -> + pure + ( Nothing, + (`M.lookup` M.singleton pname' (ExpSubst argexp)) + ) _ -> pure (Nothing, const Nothing) -- In case a function result is not immediately bound to a name, @@ -988,7 +1261,7 @@ consumedByArg _ _ _ = pure [] -- turn out to be polymorphic, in which case the list of type -- parameters will be non-empty. checkOneExp :: UncheckedExp -> TypeM ([TypeParam], Exp) -checkOneExp e = fmap fst . runTermTypeM $ do +checkOneExp e = fmap fst . runTermTypeM checkExp $ do e' <- checkExp e let t = toStruct $ typeOf e' (tparams, _, _) <- @@ -999,6 +1272,18 @@ checkOneExp e = fmap fst . runTermTypeM $ do causalityCheck e'' pure (tparams, e'') +-- | Type-check a single size expression in isolation. This expression may +-- turn out to be polymorphic, in which case it is unified with i64. +checkSizeExp :: UncheckedExp -> TypeM Exp +checkSizeExp e = fmap fst . runTermTypeM checkExp $ do + e' <- noUnique $ checkExp e + let t = toStruct $ typeOf e' + when (hasBinding e') $ + typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $ + "Size expression with binding is forbidden." + unify (mkUsage e' "Size expression") t (Scalar (Prim (Signed Int64))) + updateTypes e' + -- Verify that all sum type constructors and empty array literals have -- a size that is known (rigid or a type parameter). This is to -- ensure that we can actually determine their shape at run-time. @@ -1009,9 +1294,10 @@ causalityCheck binding_body = do let checkCausality what known t loc | (d, dloc) : _ <- mapMaybe (unknown constraints known) $ - S.toList $ - freeInType $ - toStruct t = + M.keys $ + unFV $ + freeInType $ + toStruct t = Just $ lift $ causality what (locOf loc) d dloc t | otherwise = Nothing @@ -1050,6 +1336,11 @@ causalityCheck binding_body = do onExp known e@(AppExp (LetPat _ _ bindee_e body_e _) (Info res)) = do sequencePoint known bindee_e body_e $ appResExt res pure e + onExp known e@(AppExp (Match scrutinee cs _) (Info res)) = do + new_known <- lift $ execStateT (onExp known scrutinee) mempty + void $ recurse (new_known <> known) cs + modify ((new_known <> S.fromList (appResExt res)) <>) + pure e onExp known e@(AppExp (Apply f args _) (Info res)) = do seqArgs known $ reverse $ NE.toList args pure e @@ -1104,7 +1395,7 @@ causalityCheck binding_body = do Left . TypeError loc mempty . withIndexLink "causality-check" $ "Causality check: size" dquotes (prettyName d) - "needed for type of" + <+> "needed for type of" <+> what <> colon indent 2 (pretty t) "But" @@ -1200,7 +1491,7 @@ checkFunDef :: Exp ) checkFunDef (fname, maybe_retdecl, tparams, params, body, loc) = - fmap fst . runTermTypeM $ do + fmap fst . runTermTypeM checkExp $ do (tparams', params', maybe_retdecl', RetType dims rettype', body') <- checkBinding (fname, maybe_retdecl, tparams, params, body, loc) @@ -1303,10 +1594,10 @@ inferredReturnType loc params t = do -- These we must turn into fresh type variables, which will be -- existential in the return type. fmap (toStruct . fst) $ - unscopeType loc hidden_params $ + unscopePatType loc hidden_params $ inferReturnUniqueness params t where - hidden_params = M.filterWithKey (const . (`S.member` hidden)) $ foldMap patternMap params + hidden_params = M.keysSet $ M.filterWithKey (const . (`S.member` hidden)) $ foldMap patternMap params hidden = hiddenParamNames params checkReturnAlias :: SrcLoc -> TypeBase () () -> [Pat] -> PatType -> TermTypeM () @@ -1397,7 +1688,8 @@ checkBinding (fname, maybe_retdecl, tparams, params, body, loc) = verifyFunctionParams (Just fname) params'' (tparams'', params''', rettype') <- - letGeneralise fname loc tparams' params'' rettype + letGeneralise fname loc tparams' params'' + =<< unscopeUnknowable rettype checkGlobalAliases params'' body_t loc @@ -1412,10 +1704,10 @@ sizeNamesPos (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = onParam t1 <> sizeNamesP onParam (Scalar Arrow {}) = mempty onParam (Scalar (Record fs)) = mconcat $ map onParam $ M.elems fs onParam (Scalar (TypeVar _ _ _ targs)) = mconcat $ map onTypeArg targs - onParam t = freeInType t - onTypeArg (TypeArgDim (NamedSize d) _) = S.singleton $ qualLeaf d - onTypeArg (TypeArgDim _ _) = mempty - onTypeArg (TypeArgType t _) = onParam t + onParam t = fvVars $ freeInType t + onTypeArg (TypeArgDim (SizeExpr (Var d _ _))) = S.singleton $ qualLeaf d + onTypeArg (TypeArgDim _) = mempty + onTypeArg (TypeArgType t) = onParam t sizeNamesPos _ = mempty checkGlobalAliases :: [Pat] -> PatType -> SrcLoc -> TermTypeM () @@ -1501,7 +1793,7 @@ verifyFunctionParams fname params = verifyParams (foldMap patNames params) =<< mapM updateTypes params where verifyParams forbidden (p : ps) - | d : _ <- S.toList $ freeInPat p `S.intersection` forbidden = + | d : _ <- S.toList $ fvVars (freeInPat p) `S.intersection` forbidden = typeError p mempty . withIndexLink "inaccessible-size" $ "Parameter" <+> dquotes (pretty p) @@ -1549,8 +1841,8 @@ injectExt ext ret = RetType ext_here $ deeper ret Scalar $ TypeVar as u tn $ map deeperArg targs deeper t@Array {} = t - deeperArg (TypeArgType t loc) = TypeArgType (deeper t) loc - deeperArg (TypeArgDim d loc) = TypeArgDim d loc + deeperArg (TypeArgType t) = TypeArgType $ deeper t + deeperArg (TypeArgDim d) = TypeArgDim d -- | Find all type variables in the given type that are covered by the -- constraints, and produce type parameters that close over them. @@ -1575,13 +1867,13 @@ closeOverTypes defname defloc tparams paramts ret substs = do _ -> Nothing pure ( tparams ++ more_tparams, - injectExt (retext ++ mapMaybe mkExt (S.toList $ freeInType ret)) ret + injectExt (retext ++ mapMaybe mkExt (M.keys $ unFV $ freeInType ret)) ret ) where -- Diet does not matter here. t = foldFunType (zip (repeat Observe) paramts) $ RetType [] ret to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs - visible = typeVars t <> freeInType t + visible = typeVars t <> fvVars (freeInType t) (produced_sizes, param_sizes) = dimUses t @@ -1598,7 +1890,7 @@ closeOverTypes defname defloc tparams paramts ret substs = do closeOver (k, UnknowableSize _ _) | k `S.member` param_sizes, k `S.notMember` produced_sizes = do - notes <- dimNotes defloc $ NamedSize $ qualName k + notes <- dimNotes defloc $ sizeVar (qualName k) mempty typeError defloc notes . withIndexLink "unknowable-param-def" $ "Unknowable size" <+> dquotes (prettyName k) @@ -1655,7 +1947,7 @@ letGeneralise defname defloc tparams params rettype = let used_sizes = foldMap freeInType $ rettype'' : map patternStructType params - case filter ((`S.notMember` used_sizes) . typeParamName) $ + case filter ((`M.notMember` unFV used_sizes) . typeParamName) $ filter isSizeParam tparams' of [] -> pure () tp : _ -> unusedSize $ SizeBinder (typeParamName tp) (srclocOf tp) @@ -1683,17 +1975,17 @@ checkFunBody params body maybe_rettype loc = do -- names into existential sizes instead. let hidden = hiddenParamNames params (body_t', _) <- - unscopeType + unscopePatType loc - ( M.filterWithKey (const . (`S.member` hidden)) $ - foldMap patternMap params + ( M.keysSet $ + M.filterWithKey (const . (`S.member` hidden)) $ + foldMap patternMap params ) body_t - let usage = mkUsage (srclocOf body) "return type annotation" + let usage = mkUsage body "return type annotation" onFailure (CheckingReturn rettype (toStruct body_t')) $ - expect usage rettype $ - toStruct body_t' + unify usage rettype (toStruct body_t') Nothing -> pure () pure body' diff --git a/src/Language/Futhark/TypeChecker/Terms/DoLoop.hs b/src/Language/Futhark/TypeChecker/Terms/DoLoop.hs index c26f8bc23b..b7861ca34c 100644 --- a/src/Language/Futhark/TypeChecker/Terms/DoLoop.hs +++ b/src/Language/Futhark/TypeChecker/Terms/DoLoop.hs @@ -35,9 +35,9 @@ getAreSame = check <$> getConstraints where check constraints x y = case (M.lookup x constraints, M.lookup y constraints) of - (Just (_, Size (Just (NamedSize x')) _), _) -> + (Just (_, Size (Just (Var x' _ _)) _), _) -> check constraints (qualLeaf x') y - (_, Just (_, Size (Just (NamedSize y')) _)) -> + (_, Just (_, Size (Just (Var y' _ _)) _)) -> check constraints x (qualLeaf y') _ -> x == y @@ -45,45 +45,43 @@ getAreSame = check <$> getConstraints -- | Replace specified sizes with distinct fresh size variables. someDimsFreshInType :: SrcLoc -> - Rigidity -> Name -> [VName] -> TypeBase Size als -> TermTypeM (TypeBase Size als) -someDimsFreshInType loc r desc fresh t = do +someDimsFreshInType loc desc fresh t = do areSameSize <- getAreSame let freshen v = any (areSameSize v) fresh bitraverse (onDim freshen) pure t where - onDim freshen (NamedSize d) + onDim freshen (SizeExpr (Var d _ _)) | freshen $ qualLeaf d = do - v <- newDimVar loc r desc - pure $ NamedSize $ qualName v + v <- newFlexibleDim (mkUsage' loc) desc + pure $ sizeFromName (qualName v) loc onDim _ d = pure d -- | Replace the specified sizes with fresh size variables of the -- specified ridigity. Returns the new fresh size variables. freshDimsInType :: - SrcLoc -> + Usage -> Rigidity -> Name -> [VName] -> TypeBase Size als -> TermTypeM (TypeBase Size als, [VName]) -freshDimsInType loc r desc fresh t = do +freshDimsInType usage r desc fresh t = do areSameSize <- getAreSame - let freshen v = any (areSameSize v) fresh - second M.elems <$> runStateT (bitraverse (onDim freshen) pure t) mempty + second (map snd) <$> runStateT (bitraverse (onDim areSameSize) pure t) mempty where - onDim freshen (NamedSize d) - | freshen $ qualLeaf d = do - prev_subst <- gets $ M.lookup $ qualLeaf d + onDim areSameSize (SizeExpr (Var (QualName _ d) _ _)) + | any (areSameSize d) fresh = do + prev_subst <- gets $ L.find (areSameSize d . fst) case prev_subst of - Just d' -> pure $ NamedSize $ qualName d' + Just (_, d') -> pure $ sizeFromName (qualName d') $ srclocOf usage Nothing -> do - v <- lift $ newDimVar loc r desc - modify $ M.insert (qualLeaf d) v - pure $ NamedSize $ qualName v + v <- lift $ newDimVar usage r desc + modify ((d, v) :) + pure $ sizeFromName (qualName v) $ srclocOf usage onDim _ d = pure d -- | Mark bindings of names in "consumed" as Unique. @@ -187,11 +185,11 @@ data ArgSource = Initial | BodyResult wellTypedLoopArg :: ArgSource -> [VName] -> Pat -> Exp -> TermTypeM () wellTypedLoopArg src sparams pat arg = do (merge_t, _) <- - freshDimsInType (srclocOf arg) Nonrigid "loop" sparams $ + freshDimsInType (mkUsage arg desc) Nonrigid "loop" sparams $ toStruct (patternType pat) arg_t <- toStruct <$> expTypeFully arg onFailure (checking merge_t arg_t) $ - unify (mkUsage (srclocOf arg) desc) merge_t arg_t + unify (mkUsage arg desc) merge_t arg_t where (checking, desc) = case src of @@ -222,7 +220,7 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = sequentially (checkExp mergeexp) $ \mergeexp' _ -> do known_before <- M.keysSet <$> getConstraints zeroOrderType - (mkUsage (srclocOf mergeexp) "use as loop variable") + (mkUsage mergeexp "use as loop variable") "type used as loop variable" . toStruct =<< expTypeFully mergeexp' @@ -256,7 +254,8 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = -- properly later. (merge_t, new_dims_map) <- -- dim handling (1) - allDimsFreshInType loc Nonrigid "loop_d" . flip setAliases mempty + allDimsFreshInType (mkUsage loc "loop parameter type inference") Nonrigid "loop_d" + . flip setAliases mempty =<< expTypeFully mergeexp' let new_dims_to_initial_dim = M.toList new_dims_map new_dims = map fst new_dims_to_initial_dim @@ -265,14 +264,14 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = let checkLoopReturnSize mergepat' loopbody' = do loopbody_t <- expTypeFully loopbody' pat_t <- - someDimsFreshInType loc Nonrigid "loop" new_dims + someDimsFreshInType loc "loop" new_dims =<< normTypeFully (patternType mergepat') -- We are ignoring the dimensions here, because any mismatches -- should be turned into fresh size variables. onFailure (CheckingLoopBody (toStruct pat_t) (toStruct loopbody_t)) $ unify - (mkUsage (srclocOf loopbody) "matching loop body to loop pattern") + (mkUsage loopbody "matching loop body to loop pattern") (toStruct pat_t) (toStruct loopbody_t) @@ -282,15 +281,19 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = areSameSize <- getAreSame let onDims _ x y | x == y = pure x - onDims _ (NamedSize v) d - | Just (v', d') <- - L.find (areSameSize (qualLeaf v) . fst) new_dims_to_initial_dim = do - if d' == d - then modify $ first $ M.insert v' (SizeSubst d) - else - unless (qualLeaf v `S.member` known_before) $ - modify (second (qualLeaf v :)) - pure $ NamedSize v + onDims _ x@(SizeExpr e) d = do + let vs = M.keys . unFV $ freeInExp e + forM_ vs $ \v -> do + case L.find (areSameSize v . fst) new_dims_to_initial_dim of + Just (_, d'@(SizeExpr e')) -> + if d' == d + then modify $ first $ M.insert v $ ExpSubst e' + else + unless (v `S.member` known_before) $ + modify (second (v :)) + _ -> + pure () + pure x onDims _ x _ = pure x loopbody_t' <- normTypeFully loopbody_t merge_t' <- normTypeFully merge_t @@ -302,8 +305,8 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = -- replaced with the invariant size in the loop body. Failure -- to do this can cause type annotations to still refer to -- new_dims. - let dimToInit (v, SizeSubst d) = - constrain v $ Size (Just d) (mkUsage loc "size of loop parameter") + let dimToInit (v, ExpSubst e) = + constrain v $ Size (Just e) (mkUsage loc "size of loop parameter") dimToInit _ = pure () mapM_ dimToInit $ M.toList init_substs @@ -350,7 +353,7 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = loopbody' ) ForIn xpat e -> do - (arr_t, _) <- newArrayType (srclocOf e) "e" 1 + (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1 e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e t <- expTypeFully e' case t of @@ -392,7 +395,7 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = mergepat'' <- do loopbody_t <- expTypeFully loopbody' convergePat loc mergepat' (allConsumed bodyflow) loopbody_t $ - mkUsage (srclocOf loopbody') "being (part of) the result of the loop body" + mkUsage loopbody' "being (part of) the result of the loop body" merge_t' <- expTypeFully mergeexp' let consumeMerge (Id _ (Info pt) ploc) mt @@ -412,8 +415,12 @@ checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = wellTypedLoopArg Initial sparams mergepat'' mergeexp' (loopt, retext) <- - freshDimsInType loc (Rigid RigidLoop) "loop" sparams $ - loopReturnType mergepat'' merge_t' + freshDimsInType + (mkUsage loc "inference of loop result type") + (Rigid RigidLoop) + "loop" + sparams + $ loopReturnType mergepat'' merge_t' -- We set all of the uniqueness to be unique. This is intentional, -- and matches what happens for function calls. Those arrays that -- really *cannot* be consumed will alias something unconsumable, diff --git a/src/Language/Futhark/TypeChecker/Terms/Monad.hs b/src/Language/Futhark/TypeChecker/Terms/Monad.hs index 2db6d36589..b6d64f7498 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Monad.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Monad.hs @@ -38,7 +38,6 @@ module Language.Futhark.TypeChecker.Terms.Monad isInt64, maybeDimFromExp, dimFromExp, - sizeFromArg, noSizeEscape, -- * Control flow @@ -94,12 +93,12 @@ import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Monad hiding (BoundV) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types -import Language.Futhark.TypeChecker.Unify hiding (Usage) +import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) --- Uniqueness -data Usage +data VarUse = Consumed SrcLoc | Observed SrcLoc deriving (Eq, Ord, Show) @@ -139,7 +138,7 @@ seminullOccurrence occ = S.null (observed occ) && maybe True S.null (consumed oc type Occurrences = [Occurrence] -type UsageMap = M.Map VName [Usage] +type UsageMap = M.Map VName [VarUse] usageMap :: Occurrences -> UsageMap usageMap = foldl comb M.empty @@ -149,7 +148,7 @@ usageMap = foldl comb M.empty in S.foldl' (ins $ Consumed loc) m' $ fromMaybe mempty cons ins v m k = M.insertWith (++) k [v] m -combineOccurrences :: VName -> Usage -> Usage -> TermTypeM Usage +combineOccurrences :: VName -> VarUse -> VarUse -> TermTypeM VarUse combineOccurrences _ (Observed loc) (Observed _) = pure $ Observed loc combineOccurrences name (Consumed wloc) (Observed rloc) = useAfterConsume name rloc wloc @@ -378,7 +377,8 @@ instance Pretty Checking where data TermEnv = TermEnv { termScope :: TermScope, termChecking :: Maybe Checking, - termLevel :: Level + termLevel :: Level, + termChecker :: UncheckedExp -> TermTypeM Exp } data TermScope = TermScope @@ -504,11 +504,11 @@ instance MonadUnify TermTypeM where curLevel = asks termLevel - newDimVar loc rigidity name = do + newDimVar usage rigidity name = do dim <- newTypeName name case rigidity of - Rigid rsrc -> constrain dim $ UnknowableSize loc rsrc - Nonrigid -> constrain dim $ Size Nothing $ mkUsage' loc + Rigid rsrc -> constrain dim $ UnknowableSize (srclocOf usage) rsrc + Nonrigid -> constrain dim $ Size Nothing usage pure dim unifyError loc notes bcs doc = do @@ -579,7 +579,7 @@ instantiateTypeParam qn loc tparam = do TypeParamDim {} -> do constrain v . Size Nothing . mkUsage loc . docText $ "instantiated size parameter of " <> dquotes (pretty qn) - pure (v, SizeSubst $ NamedSize $ qualName v) + pure (v, ExpSubst $ sizeVar (qualName v) loc) checkQualNameWithEnv :: Namespace -> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName) checkQualNameWithEnv space qn@(QualName quals name) loc = do @@ -622,6 +622,13 @@ localScope :: (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} instance MonadTypeChecker TermTypeM where + checkExpForSize e = do + checker <- asks termChecker + e' <- noUnique $ checker e + let t = toStruct $ typeOf e' + unify (mkUsage (srclocOf e') "Size expression") t (Scalar (Prim (Signed Int64))) + updateTypes e' + warn loc problem = liftTypeM $ warn loc problem newName = liftTypeM . newName newID = liftTypeM . newID @@ -699,15 +706,6 @@ instance MonadTypeChecker TermTypeM where maybe (toStruct argtype) (Scalar . Prim) rt ) - checkNamedSize loc v = do - (v', t) <- lookupVar loc v - onFailure (CheckingRequired [Scalar $ Prim $ Signed Int64] (toStruct t)) $ - unify (mkUsage loc "use as array size") (toStruct t) $ - Scalar $ - Prim $ - Signed Int64 - pure v' - typeError loc notes s = do checking <- asks termChecking case checking of @@ -731,15 +729,15 @@ extSize loc e = do RigidBound $ prettyTextOneLine e' SourceSlice d i j s -> RigidSlice d $ prettyTextOneLine $ DimSlice i j s - d <- newDimVar loc (Rigid rsrc) "n" + d <- newRigidDim loc rsrc "n" modify $ \s -> s {stateDimTable = M.insert e d $ stateDimTable s} pure - ( NamedSize $ qualName d, + ( sizeFromName (qualName d) loc, Just d ) Just d -> pure - ( NamedSize $ qualName d, + ( sizeFromName (qualName d) loc, Just d ) @@ -759,31 +757,32 @@ expType = normPatType . typeOf expTypeFully :: Exp -> TermTypeM PatType expTypeFully = normTypeFully . typeOf -newArrayType :: SrcLoc -> Name -> Int -> TermTypeM (StructType, StructType) -newArrayType loc desc r = do +newArrayType :: Usage -> Name -> Int -> TermTypeM (StructType, StructType) +newArrayType usage desc r = do v <- newTypeName desc - constrain v $ NoConstraint Unlifted $ mkUsage' loc - dims <- replicateM r $ newDimVar loc Nonrigid "dim" + constrain v $ NoConstraint Unlifted usage + dims <- replicateM r $ newDimVar usage Nonrigid "dim" let rowt = TypeVar () Nonunique (qualName v) [] + mkSize = flip sizeFromName (srclocOf usage) . qualName pure - ( Array () Nonunique (Shape $ map (NamedSize . qualName) dims) rowt, + ( Array () Nonunique (Shape $ map mkSize dims) rowt, Scalar rowt ) -- | Replace *all* dimensions with distinct fresh size variables. allDimsFreshInType :: - SrcLoc -> + Usage -> Rigidity -> Name -> TypeBase Size als -> TermTypeM (TypeBase Size als, M.Map VName Size) -allDimsFreshInType loc r desc t = +allDimsFreshInType usage r desc t = runStateT (bitraverse onDim pure t) mempty where onDim d = do - v <- lift $ newDimVar loc r desc + v <- lift $ newDimVar usage r desc modify $ M.insert v d - pure $ NamedSize $ qualName v + pure $ sizeFromName (qualName v) $ srclocOf usage -- | Replace all type variables with their concrete types. updateTypes :: ASTMappable e => e -> TermTypeM e @@ -827,7 +826,7 @@ termCheckTypeExp te = do -- Observe the sizes so we do not get any warnings about them not -- being used. - mapM_ observeDim $ freeInType st + mapM_ observeDim $ fvVars $ freeInType st pure (te', svars, RetType dims st) where observeDim v = @@ -856,13 +855,14 @@ isInt64 :: Exp -> Maybe Int64 isInt64 (Literal (SignedValue (Int64Value k')) _) = Just $ fromIntegral k' isInt64 (IntLit k' _ _) = Just $ fromInteger k' isInt64 (Negate x _) = negate <$> isInt64 x +isInt64 (Parens x _) = isInt64 x isInt64 _ = Nothing maybeDimFromExp :: Exp -> Maybe Size -maybeDimFromExp (Var v _ _) = Just $ NamedSize v +maybeDimFromExp (Var v typ loc) = Just $ SizeExpr $ Var v typ loc maybeDimFromExp (Parens e _) = maybeDimFromExp e maybeDimFromExp (QualParens _ e _) = maybeDimFromExp e -maybeDimFromExp e = ConstSize . fromIntegral <$> isInt64 e +maybeDimFromExp e = flip sizeFromInteger mempty . fromIntegral <$> isInt64 e dimFromExp :: (Exp -> SizeSource) -> Exp -> TermTypeM (Size, Maybe VName) dimFromExp rf (Attr _ e _) = dimFromExp rf e @@ -875,9 +875,6 @@ dimFromExp rf e | otherwise = extSize (srclocOf e) $ rf e -sizeFromArg :: Maybe (QualName VName) -> Exp -> TermTypeM (Size, Maybe VName) -sizeFromArg fname = dimFromExp $ SourceArg (FName fname) . bareExp - -- | Any argument sizes created with 'extSize' inside the given action -- will be removed once the action finishes. This is to ensure that -- just because e.g. @n+1@ appears as a size in one branch of a @@ -1024,14 +1021,15 @@ initialTermScope = Just (name, EqualityF) addIntrinsicF _ = Nothing -runTermTypeM :: TermTypeM a -> TypeM (a, Occurrences) -runTermTypeM (TermTypeM m) = do +runTermTypeM :: (UncheckedExp -> TermTypeM Exp) -> TermTypeM a -> TypeM (a, Occurrences) +runTermTypeM checker (TermTypeM m) = do initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv let initial_tenv = TermEnv { termScope = initial_scope, termChecking = Nothing, - termLevel = 0 + termLevel = 0, + termChecker = checker } second stateOccs <$> runStateT diff --git a/src/Language/Futhark/TypeChecker/Terms/Pat.hs b/src/Language/Futhark/TypeChecker/Terms/Pat.hs index 7a8e3ed9c7..8d004b9ee7 100644 --- a/src/Language/Futhark/TypeChecker/Terms/Pat.hs +++ b/src/Language/Futhark/TypeChecker/Terms/Pat.hs @@ -34,17 +34,17 @@ nonrigidFor :: [SizeBinder VName] -> StructType -> TermTypeM StructType nonrigidFor [] t = pure t -- Minor optimisation. nonrigidFor sizes t = evalStateT (bitraverse onDim pure t) mempty where - onDim (NamedSize (QualName _ v)) + onDim (SizeExpr (Var (QualName _ v) typ loc)) | Just size <- find ((== v) . sizeName) sizes = do prev <- gets $ lookup v case prev of Nothing -> do v' <- lift $ newID $ baseName v - lift $ constrain v' $ Size Nothing $ mkUsage' $ srclocOf size + lift $ constrain v' $ Size Nothing $ mkUsage size "ambiguous size of bound expression" modify ((v, v') :) - pure $ NamedSize $ qualName v' + pure $ SizeExpr $ Var (qualName v') typ loc Just v' -> - pure $ NamedSize $ qualName v' + pure $ SizeExpr $ Var (qualName v') typ loc onDim d = pure d -- | The set of in-scope variables that are being aliased. @@ -242,8 +242,8 @@ bindingPat sizes p t m = do Ident v (Info (Scalar $ Prim $ Signed Int64)) loc mapM_ (observe . ident) sizes - let used_sizes = freeInType $ patternStructType p' - case filter ((`S.notMember` used_sizes) . sizeName) sizes of + let used_sizes = unFV $ freeInType $ patternStructType p' + case filter ((`M.notMember` used_sizes) . sizeName) sizes of [] -> m p' size : _ -> unusedSize size diff --git a/src/Language/Futhark/TypeChecker/Types.hs b/src/Language/Futhark/TypeChecker/Types.hs index 454189c546..d0b026d20c 100644 --- a/src/Language/Futhark/TypeChecker/Types.hs +++ b/src/Language/Futhark/TypeChecker/Types.hs @@ -41,13 +41,13 @@ mustBeExplicitAux :: StructType -> M.Map VName Bool mustBeExplicitAux t = execState (traverseDims onDim t) mempty where - onDim bound _ (NamedSize d) + onDim bound _ (SizeExpr (Var d _ _)) | qualLeaf d `S.member` bound = modify $ \s -> M.insertWith (&&) (qualLeaf d) False s - onDim _ PosImmediate (NamedSize d) = + onDim _ PosImmediate (SizeExpr (Var d _ _)) = modify $ \s -> M.insertWith (&&) (qualLeaf d) False s - onDim _ _ (NamedSize d) = - modify $ M.insertWith (&&) (qualLeaf d) True + onDim _ _ (SizeExpr e) = + modify $ M.unionWith (&&) (M.map (const True) (unFV $ freeInExp e)) onDim _ _ _ = pure () @@ -68,11 +68,7 @@ determineSizeWitnesses t = mustBeExplicitInBinding :: StructType -> S.Set VName mustBeExplicitInBinding bind_t = let (ts, ret) = unfoldFunType bind_t - alsoRet = - M.unionWith (&&) $ - M.fromList $ - zip (S.toList $ freeInType ret) $ - repeat True + alsoRet = M.unionWith (&&) $ M.map (const True) $ unFV $ freeInType ret in S.fromList $ M.keys $ M.filter id $ alsoRet $ foldl' onType mempty $ map snd ts where onType uses t = uses <> mustBeExplicitAux t -- Left-biased union. @@ -183,10 +179,10 @@ unifyScalarTypes uf (TypeVar als1 u1 tv1 targs1) (TypeVar als2 u2 tv2 targs2) Just $ TypeVar (als1 <> als2) u3 tv1 targs3 | otherwise = Nothing where - unifyTypeArgs (TypeArgDim d1 loc) (TypeArgDim _d2 _) = - pure $ TypeArgDim d1 loc - unifyTypeArgs (TypeArgType t1 loc) (TypeArgType t2 _) = - TypeArgType <$> unifyTypesU uf t1 t2 <*> pure loc + unifyTypeArgs (TypeArgDim d1) (TypeArgDim _d2) = + pure $ TypeArgDim d1 + unifyTypeArgs (TypeArgType t1) (TypeArgType t2) = + TypeArgType <$> unifyTypesU uf t1 t2 unifyTypeArgs _ _ = Nothing unifyScalarTypes uf (Record ts1) (Record ts2) @@ -231,33 +227,13 @@ renameRetType :: MonadTypeChecker m => StructRetType -> m StructRetType renameRetType (RetType dims st) | dims /= mempty = do dims' <- mapM newName dims - let m = M.fromList $ zip dims $ map (SizeSubst . NamedSize . qualName) dims' + let mkSubst = ExpSubst . flip sizeVar mempty . qualName + m = M.fromList . zip dims $ map mkSubst dims' st' = applySubst (`M.lookup` m) st pure $ RetType dims' st' | otherwise = pure $ RetType dims st -checkExpForSize :: - MonadTypeChecker m => - ExpBase NoInfo Name -> - m (Exp, Size) -checkExpForSize (IntLit x NoInfo loc) = - pure (IntLit x int64_info loc, ConstSize $ fromInteger x) - where - int64_info = Info (Scalar (Prim (Signed Int64))) -checkExpForSize (Literal (SignedValue (Int64Value x)) loc) = - pure (Literal (SignedValue (Int64Value x)) loc, ConstSize x) -checkExpForSize (Var v NoInfo vloc) = do - v' <- checkNamedSize vloc v - pure (Var v' int64_info vloc, NamedSize v') - where - int64_info = Info (Scalar (Prim (Signed Int64))) -checkExpForSize e = - typeError - (locOf e) - mempty - "Only variables and i64 literals are allowed in size expressions." - evalTypeExp :: MonadTypeChecker m => TypeExp NoInfo Name -> @@ -329,10 +305,10 @@ evalTypeExp (TEArray d t loc) = do where checkSizeExp (SizeExpAny dloc) = do dv <- newTypeName "d" - pure ([dv], SizeExpAny dloc, NamedSize $ qualName dv) + pure ([dv], SizeExpAny dloc, sizeFromName (qualName dv) dloc) checkSizeExp (SizeExp e dloc) = do - (e', sz) <- checkExpForSize e - pure ([], SizeExp e' dloc, sz) + e' <- checkExpForSize e + pure ([], SizeExp e' dloc, SizeExpr e') -- evalTypeExp (TEUnique t loc) = do (t', svars, RetType dims st, l) <- evalTypeExp t @@ -457,18 +433,18 @@ evalTypeExp ote@TEApply {} = do "Type" <+> dquotes (pretty te') <+> "is not a type constructor." checkSizeExp (SizeExp e dloc) = do - (e', sz) <- checkExpForSize e + e' <- checkExpForSize e pure ( TypeArgExpSize (SizeExp e' dloc), [], - SizeSubst sz + ExpSubst e' ) checkSizeExp (SizeExpAny loc) = do d <- newTypeName "d" pure ( TypeArgExpSize (SizeExpAny loc), [d], - SizeSubst $ NamedSize $ qualName d + ExpSubst $ sizeVar (qualName d) loc ) checkArgApply (TypeParamDim pv _) (TypeArgExpSize d) = do @@ -618,22 +594,22 @@ checkTypeParams ps m = -- | Construct a type argument corresponding to a type parameter. typeParamToArg :: TypeParam -> StructTypeArg typeParamToArg (TypeParamDim v ploc) = - TypeArgDim (NamedSize $ qualName v) ploc -typeParamToArg (TypeParamType _ v ploc) = - TypeArgType (Scalar $ TypeVar () Nonunique (qualName v) []) ploc + TypeArgDim $ sizeFromName (qualName v) ploc +typeParamToArg (TypeParamType _ v _) = + TypeArgType $ Scalar $ TypeVar () Nonunique (qualName v) [] -- | A type substitution may be a substitution or a yet-unknown -- substitution (but which is certainly an overloaded primitive -- type!). The latter is used to remove aliases from types that are -- yet-unknown but that we know cannot carry aliases (see issue #682). -data Subst t = Subst [TypeParam] t | PrimSubst | SizeSubst Size +data Subst t = Subst [TypeParam] t | PrimSubst | ExpSubst Exp deriving (Show) instance Pretty t => Pretty (Subst t) where pretty (Subst [] t) = pretty t pretty (Subst tps t) = mconcat (map pretty tps) <> colon <+> pretty t pretty PrimSubst = "#" - pretty (SizeSubst d) = pretty d + pretty (ExpSubst e) = pretty e -- | Create a type substitution corresponding to a type binding. substFromAbbr :: TypeBinding -> Subst StructRetType @@ -645,7 +621,7 @@ type TypeSubs = VName -> Maybe (Subst StructRetType) instance Functor Subst where fmap f (Subst ps t) = Subst ps $ f t fmap _ PrimSubst = PrimSubst - fmap _ (SizeSubst v) = SizeSubst v + fmap _ (ExpSubst e) = ExpSubst e -- | Class of types which allow for substitution of types with no -- annotations for type variable names. @@ -670,10 +646,35 @@ instance Substitutable (TypeBase Size ()) where instance Substitutable (TypeBase Size Aliasing) where applySubst = substTypesAny . (fmap (fmap (second (const mempty))) .) +instance Substitutable Exp where + applySubst f = runIdentity . mapOnExp + where + mapOnExp (Var (QualName _ v) _ _) + | Just (ExpSubst e') <- f v = pure e' + mapOnExp e' = astMap mapper e' + + mapper = + ASTMapper + { mapOnExp, + mapOnName = pure, + mapOnStructType = pure . applySubst f, + mapOnPatType = pure . applySubst f, + mapOnStructRetType = pure . applySubst f, + mapOnPatRetType = pure . applySubst f + } + instance Substitutable Size where - applySubst f (NamedSize (QualName _ v)) - | Just (SizeSubst d) <- f v = d - applySubst _ d = d + applySubst f size = runIdentity $ astMap mapper size + where + mapper = + ASTMapper + { mapOnExp = pure . applySubst f, + mapOnName = pure, + mapOnStructType = pure . applySubst f, + mapOnPatType = pure . applySubst f, + mapOnStructRetType = pure . applySubst f, + mapOnPatRetType = pure . applySubst f + } instance Substitutable d => Substitutable (Shape d) where applySubst f = fmap $ applySubst f @@ -683,7 +684,7 @@ instance Substitutable Pat where where mapper = ASTMapper - { mapOnExp = pure, + { mapOnExp = pure . applySubst f, mapOnName = pure, mapOnStructType = pure . applySubst f, mapOnPatType = pure . applySubst f, @@ -701,9 +702,9 @@ applyType ps t args = substTypesAny (`M.lookup` substs) t where substs = M.fromList $ zipWith mkSubst ps args -- We are assuming everything has already been type-checked for correctness. - mkSubst (TypeParamDim pv _) (TypeArgDim d _) = - (pv, SizeSubst d) - mkSubst (TypeParamType _ pv _) (TypeArgType at _) = + mkSubst (TypeParamDim pv _) (TypeArgDim (SizeExpr e)) = + (pv, ExpSubst e) + mkSubst (TypeParamType _ pv _) (TypeArgType at) = (pv, Subst [] $ RetType [] $ second mempty at) mkSubst p a = error $ "applyType mkSubst: cannot substitute " ++ prettyString a ++ " for " ++ prettyString p @@ -733,7 +734,8 @@ substTypesRet lookupSubst ot = else do let start = maximum $ map baseTag seen_ext ext' = zipWith VName (map baseName ext) [start + 1 ..] - extsubsts = M.fromList $ zip ext $ map (SizeSubst . NamedSize . qualName) ext' + mkSubst = ExpSubst . flip sizeVar mempty . qualName + extsubsts = M.fromList $ zip ext $ map mkSubst ext' RetType [] t' = substTypesRet (`M.lookup` extsubsts) t pure $ RetType ext' t' @@ -779,12 +781,12 @@ substTypesRet lookupSubst ot = _ -> pure $ RetType (new_ext <> dims) t' - subsTypeArg (TypeArgType t loc) = do + subsTypeArg (TypeArgType t) = do let RetType dims t' = substTypesRet lookupSubst' t modify (dims ++) - pure $ TypeArgType t' loc - subsTypeArg (TypeArgDim v loc) = - pure $ TypeArgDim (applySubst lookupSubst' v) loc + pure $ TypeArgType t' + subsTypeArg (TypeArgDim v) = + pure $ TypeArgDim $ applySubst lookupSubst' v lookupSubst' = fmap (fmap $ second (const ())) . lookupSubst @@ -803,7 +805,7 @@ substTypesAny lookupSubst ot = -- AnySize. This should _never_ happen during type-checking, but -- may happen as we substitute types during monomorphisation and -- defunctorisation later on. See Note [AnySize] - let toAny (NamedSize v) + let toAny (SizeExpr (Var v _ _)) | qualLeaf v `elem` dims = AnySize Nothing toAny d = d in first toAny ot' diff --git a/src/Language/Futhark/TypeChecker/Unify.hs b/src/Language/Futhark/TypeChecker/Unify.hs index 3fc7448a95..d59973b94f 100644 --- a/src/Language/Futhark/TypeChecker/Unify.hs +++ b/src/Language/Futhark/TypeChecker/Unify.hs @@ -22,9 +22,7 @@ module Language.Futhark.TypeChecker.Unify equalityType, normPatType, normTypeFully, - instantiateEmptyArrayDims, unify, - expect, unifyMostCommon, doUnification, ) @@ -33,8 +31,6 @@ where import Control.Monad import Control.Monad.Except import Control.Monad.State -import Data.Bifunctor -import Data.Char (isAscii) import Data.List (foldl', intersect) import Data.Map.Strict qualified as M import Data.Maybe @@ -97,13 +93,13 @@ data Usage = Usage (Maybe T.Text) SrcLoc deriving (Show) -- | Construct a 'Usage' from a location and a description. -mkUsage :: SrcLoc -> T.Text -> Usage -mkUsage = flip (Usage . Just) +mkUsage :: Located a => a -> T.Text -> Usage +mkUsage = flip (Usage . Just) . srclocOf -- | Construct a 'Usage' that has just a location, but no particular -- description. -mkUsage' :: SrcLoc -> Usage -mkUsage' = Usage Nothing +mkUsage' :: Located a => a -> Usage +mkUsage' = Usage Nothing . srclocOf instance Pretty Usage where pretty (Usage Nothing loc) = "use at " <> textwrap (locText loc) @@ -129,7 +125,7 @@ data Constraint | ParamSize SrcLoc | -- | Is not actually a type, but a term-level size, -- possibly already set to something specific. - Size (Maybe Size) Usage + Size (Maybe Exp) Usage | -- | A size that does not unify with anything - -- created from the result of applying a function -- whose return size is existential, or otherwise @@ -160,7 +156,7 @@ lookupSubst v constraints = case snd <$> M.lookup v constraints of Just (Constraint t _) -> Just $ Subst [] $ applySubst (`lookupSubst` constraints) t Just Overloaded {} -> Just PrimSubst Just (Size (Just d) _) -> - Just $ SizeSubst $ applySubst (`lookupSubst` constraints) d + Just $ ExpSubst $ applySubst (`lookupSubst` constraints) d _ -> Nothing -- | The source of a rigid size. @@ -251,8 +247,8 @@ prettySource ctx loc (RigidCond t1 t2) = -- | Retrieve notes describing the purpose or origin of the given -- t'Size'. The location is used as the *current* location, for the -- purpose of reporting relative locations. -dimNotes :: (Located a, MonadUnify m) => a -> Size -> m Notes -dimNotes ctx (NamedSize d) = do +dimNotes :: (Located a, MonadUnify m) => a -> Exp -> m Notes +dimNotes ctx (Var d _ _) = do c <- M.lookup (qualLeaf d) <$> getConstraints case c of Just (_, UnknowableSize loc rsrc) -> @@ -264,8 +260,9 @@ dimNotes _ _ = pure mempty typeNotes :: (Located a, MonadUnify m) => a -> StructType -> m Notes typeNotes ctx = fmap mconcat - . mapM (dimNotes ctx . NamedSize . qualName) - . S.toList + . mapM (dimNotes ctx . flip sizeVar mempty . qualName) + . M.keys + . unFV . freeInType typeVarNotes :: MonadUnify m => VName -> m Notes @@ -300,7 +297,11 @@ class Monad m => MonadUnify m where putConstraints $ f x newTypeVar :: Monoid als => SrcLoc -> Name -> m (TypeBase dim als) - newDimVar :: SrcLoc -> Rigidity -> Name -> m VName + newDimVar :: Usage -> Rigidity -> Name -> m VName + newRigidDim :: Located a => a -> RigidSource -> Name -> m VName + newRigidDim loc = newDimVar (mkUsage' loc) . Rigid + newFlexibleDim :: Usage -> Name -> m VName + newFlexibleDim usage = newDimVar usage Nonrigid curLevel :: m Level @@ -360,22 +361,6 @@ unsharedConstructorsMsg cs1 cs2 = filter (`notElem` M.keys cs1) (M.keys cs2) ++ filter (`notElem` M.keys cs2) (M.keys cs1) --- | Instantiate existential context in return type. -instantiateEmptyArrayDims :: - MonadUnify m => - SrcLoc -> - Rigidity -> - RetTypeBase Size als -> - m (TypeBase Size als, [VName]) -instantiateEmptyArrayDims tloc r (RetType dims t) = do - dims' <- mapM new dims - pure (first (onDim $ zip dims dims') t, dims') - where - new = newDimVar tloc r . nameFromString . takeWhile isAscii . baseString - onDim dims' (NamedSize d) = - NamedSize $ maybe d qualName (lookup (qualLeaf d) dims') - onDim _ d = d - -- | Is the given type variable the name of an abstract type or type -- parameter, which we cannot substitute? isRigid :: VName -> Constraints -> Bool @@ -389,16 +374,16 @@ isNonRigid v constraints = do guard $ not $ rigidConstraint c pure lvl -type UnifyDims m = - BreadCrumbs -> [VName] -> (VName -> Maybe Int) -> Size -> Size -> m () +type UnifySizes m = + BreadCrumbs -> [VName] -> (VName -> Maybe Int) -> Exp -> Exp -> m () -flipUnifyDims :: UnifyDims m -> UnifyDims m -flipUnifyDims onDims bcs bound nonrigid t1 t2 = +flipUnifySizes :: UnifySizes m -> UnifySizes m +flipUnifySizes onDims bcs bound nonrigid t1 t2 = onDims bcs bound nonrigid t2 t1 unifyWith :: MonadUnify m => - UnifyDims m -> + UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -426,12 +411,12 @@ unifyWith onDims usage = subunify False -- We may have to flip the order of future calls to -- onDims inside linkVarToType. linkDims - | ord' = flipUnifyDims onDims + | ord' = flipUnifySizes onDims | otherwise = onDims - unifyTypeArg bcs' (TypeArgDim d1 _) (TypeArgDim d2 _) = + unifyTypeArg bcs' (TypeArgDim (SizeExpr d1)) (TypeArgDim (SizeExpr d2)) = onDims' bcs' (swap ord d1 d2) - unifyTypeArg bcs' (TypeArgType t _) (TypeArgType arg_t _) = + unifyTypeArg bcs' (TypeArgType t) (TypeArgType arg_t) = subunify ord bound bcs' t arg_t unifyTypeArg bcs' _ _ = unifyError @@ -532,7 +517,7 @@ unifyWith onDims usage = subunify False case (p1, p2) of (Named p1', Named p2') -> let f v - | v == p2' = Just $ SizeSubst $ NamedSize $ qualName p1' + | v == p2' = Just $ ExpSubst $ sizeVar (qualName p1') mempty | otherwise = Nothing in (b1, applySubst f b2) (_, _) -> @@ -541,8 +526,8 @@ unifyWith onDims usage = subunify False pname (Named x) = Just x pname Unnamed = Nothing (Array {}, Array {}) - | Shape (t1_d : _) <- arrayShape t1', - Shape (t2_d : _) <- arrayShape t2', + | Shape (SizeExpr t1_d : _) <- arrayShape t1', + Shape (SizeExpr t2_d : _) <- arrayShape t2', Just t1'' <- peelArray 1 t1', Just t2'' <- peelArray 1 t2' -> do onDims' bcs (swap ord t1_d t2_d) @@ -558,55 +543,37 @@ unifyWith onDims usage = subunify False | t1' == t2' -> pure () | otherwise -> failure -unifyDims :: MonadUnify m => Usage -> UnifyDims m -unifyDims _ _ _ _ d1 d2 - | d1 == d2 = pure () -unifyDims usage bcs _ nonrigid (NamedSize (QualName _ d1)) d2 - | Just lvl1 <- nonrigid d1 = - linkVarToDim usage bcs d1 lvl1 d2 -unifyDims usage bcs _ nonrigid d1 (NamedSize (QualName _ d2)) - | Just lvl2 <- nonrigid d2 = - linkVarToDim usage bcs d2 lvl2 d1 -unifyDims usage bcs _ _ d1 d2 = do - notes <- (<>) <$> dimNotes usage d1 <*> dimNotes usage d2 +anyBound :: [VName] -> ExpBase Info VName -> Bool +anyBound bound e = any (`S.member` fvVars (freeInExp e)) bound + +unifySizes :: MonadUnify m => Usage -> UnifySizes m +unifySizes usage bcs bound nonrigid e1 e2 + | Just es <- similarExps e1 e2 = + mapM_ (uncurry $ unifySizes usage bcs bound nonrigid) es +unifySizes usage bcs bound nonrigid (Var v1 _ _) e2 + | Just lvl1 <- nonrigid (qualLeaf v1), + not (anyBound bound e2) || (qualLeaf v1 `elem` bound) = + linkVarToDim usage bcs (qualLeaf v1) lvl1 e2 +unifySizes usage bcs bound nonrigid e1 (Var v2 _ _) + | Just lvl2 <- nonrigid (qualLeaf v2), + not (anyBound bound e1) || (qualLeaf v2 `elem` bound) = + linkVarToDim usage bcs (qualLeaf v2) lvl2 e1 +unifySizes usage bcs _ _ e1 e2 = do + notes <- (<>) <$> dimNotes usage e2 <*> dimNotes usage e2 unifyError usage notes bcs $ - "Dimensions" - <+> dquotes (pretty d1) + "Sizes" + <+> dquotes (pretty e1) <+> "and" - <+> dquotes (pretty d2) + <+> dquotes (pretty e2) <+> "do not match." -- | Unifies two types. unify :: MonadUnify m => Usage -> StructType -> StructType -> m () -unify usage = unifyWith (unifyDims usage) usage mempty noBreadCrumbs +unify usage = unifyWith (unifySizes usage) usage mempty noBreadCrumbs -- | @expect super sub@ checks that @sub@ is a subtype of @super@. expect :: MonadUnify m => Usage -> StructType -> StructType -> m () -expect usage = unifyWith onDims usage mempty noBreadCrumbs - where - onDims _ _ _ d1 d2 - | d1 == d2 = pure () - -- We identify existentially bound names by them being nonrigid - -- and yet bound. It's OK to unify with those. - onDims bcs bound nonrigid (NamedSize (QualName _ d1)) d2 - | Just lvl1 <- nonrigid d1, - not (boundParam bound d2) || (d1 `elem` bound) = - linkVarToDim usage bcs d1 lvl1 d2 - onDims bcs bound nonrigid d1 (NamedSize (QualName _ d2)) - | Just lvl2 <- nonrigid d2, - not (boundParam bound d1) || (d2 `elem` bound) = - linkVarToDim usage bcs d2 lvl2 d1 - onDims bcs _ _ d1 d2 = do - notes <- (<>) <$> dimNotes usage d1 <*> dimNotes usage d2 - unifyError usage notes bcs $ - "Dimensions" - <+> dquotes (pretty d1) - <+> "and" - <+> dquotes (pretty d2) - <+> "do not match." - - boundParam bound (NamedSize (QualName _ d)) = d `elem` bound - boundParam _ _ = False +expect usage = unifyWith (unifySizes usage) usage mempty noBreadCrumbs occursCheck :: MonadUnify m => @@ -636,7 +603,7 @@ scopeCheck usage bcs vn max_lvl tp = do checkType constraints tp where checkType constraints t = - mapM_ (check constraints) $ typeVars t <> freeInType t + mapM_ (check constraints) $ typeVars t <> fvVars (freeInType t) check constraints v | Just (lvl, c) <- M.lookup v constraints, @@ -661,7 +628,7 @@ scopeCheck usage bcs vn max_lvl tp = do linkVarToType :: MonadUnify m => - UnifyDims m -> + UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -706,7 +673,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do link arrayElemTypeWith usage (unliftedBcs unlift_usage) tp - when (any (`elem` bound) (freeInType tp)) $ + when (any (`elem` bound) (fvVars (freeInType tp))) $ unifyError usage mempty bcs $ "Type variable" <+> prettyName vn @@ -742,7 +709,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do | all (`M.member` tp_fields) $ M.keys required_fields -> do required_fields' <- mapM normTypeFully required_fields let tp' = Scalar $ Record $ required_fields <> tp_fields -- Crucially left-biased. - ext = filter (`S.member` freeInType tp') bound + ext = filter (`M.member` (unFV $ freeInType tp')) bound modifyConstraints $ M.insert vn (lvl, Constraint (RetType ext tp') usage) unifySharedFields onDims usage bound bcs required_fields' tp_fields @@ -784,7 +751,7 @@ linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do Scalar (Sum ts) | all (`M.member` ts) $ M.keys required_cs -> do let tp' = Scalar $ Sum $ required_cs <> ts -- Crucially left-biased. - ext = filter (`S.member` freeInType tp') bound + ext = filter (`M.member` (unFV $ freeInType tp')) bound modifyConstraints $ M.insert vn (lvl, Constraint (RetType ext tp') usage) unifySharedConstructors onDims usage bound bcs required_cs ts @@ -837,31 +804,41 @@ linkVarToDim :: BreadCrumbs -> VName -> Level -> - Size -> + Exp -> m () -linkVarToDim usage bcs vn lvl dim = do +linkVarToDim usage bcs vn lvl e = do constraints <- getConstraints - case dim of - NamedSize dim' - | Just (dim_lvl, c) <- qualLeaf dim' `M.lookup` constraints, - dim_lvl > lvl -> + mapM_ (checkVar constraints) $ M.keys $ unFV $ freeInExp e + + modifyConstraints $ M.insert vn (lvl, Size (Just e) usage) + where + checkVar constraints dim' + | Just (dim_lvl, c) <- dim' `M.lookup` constraints, + dim_lvl > lvl = case c of ParamSize {} -> do - notes <- dimNotes usage dim + notes <- dimNotes usage e unifyError usage notes bcs $ "Cannot unify size variable" - <+> dquotes (pretty dim') + <+> dquotes (pretty e) <+> "with" <+> dquotes (prettyName vn) <+> "(scope violation)." "This is because" - <+> dquotes (pretty dim') + <+> dquotes (pretty $ qualName dim') <+> "is rigidly bound in a deeper scope." - _ -> modifyConstraints $ M.insert (qualLeaf dim') (lvl, c) - _ -> pure () - - modifyConstraints $ M.insert vn (lvl, Size (Just dim) usage) + _ -> modifyConstraints $ M.insert dim' (lvl, c) + checkVar _ dim' + | vn == dim' = do + notes <- dimNotes usage e + unifyError usage notes bcs $ + "Occurs check: cannot instantiate" + <+> dquotes (prettyName vn) + <+> "with" + <+> dquotes (pretty e) + <+> "." + checkVar _ _ = pure () -- | Assert that this type must be one of the given primitive types. mustBeOneOf :: MonadUnify m => [PrimType] -> Usage -> StructType -> m () @@ -1035,7 +1012,7 @@ arrayElemType usage desc = unifySharedFields :: MonadUnify m => - UnifyDims m -> + UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -1048,7 +1025,7 @@ unifySharedFields onDims usage bound bcs fs1 fs2 = unifySharedConstructors :: MonadUnify m => - UnifyDims m -> + UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -1108,7 +1085,7 @@ mustHaveConstr usage c t fs = do mustHaveFieldWith :: MonadUnify m => - UnifyDims m -> + UnifySizes m -> Usage -> [VName] -> BreadCrumbs -> @@ -1155,7 +1132,7 @@ mustHaveField :: Name -> PatType -> m PatType -mustHaveField usage = mustHaveFieldWith (unifyDims usage) usage mempty noBreadCrumbs +mustHaveField usage = mustHaveFieldWith (unifySizes usage) usage mempty noBreadCrumbs newDimOnMismatch :: (Monoid as, MonadUnify m) => @@ -1167,7 +1144,7 @@ newDimOnMismatch loc t1 t2 = do (t, seen) <- runStateT (matchDims onDims t1 t2) mempty pure (t, M.elems seen) where - r = Rigid $ RigidCond (toStruct t1) (toStruct t2) + r = RigidCond (toStruct t1) (toStruct t2) onDims _ d1 d2 | d1 == d2 = pure d1 | otherwise = do @@ -1175,11 +1152,11 @@ newDimOnMismatch loc t1 t2 = do -- same new size. maybe_d <- gets $ M.lookup (d1, d2) case maybe_d of - Just d -> pure $ NamedSize $ qualName d + Just d -> pure $ sizeFromName (qualName d) loc Nothing -> do - d <- lift $ newDimVar loc r "differ" + d <- lift $ newRigidDim loc r "differ" modify $ M.insert (d1, d2) d - pure $ NamedSize $ qualName d + pure $ sizeFromName (qualName d) loc -- | Like unification, but creates new size variables where mismatches -- occur. Returns the new dimensions thus created. @@ -1226,11 +1203,15 @@ instance MonadUnify UnifyM where modifyConstraints $ M.insert v (0, NoConstraint Lifted $ Usage Nothing loc) pure $ Scalar $ TypeVar mempty Nonunique (qualName v) [] - newDimVar loc rigidity name = do + newDimVar usage rigidity name = do dim <- newVar name case rigidity of - Rigid src -> modifyConstraints $ M.insert dim (0, UnknowableSize loc src) - Nonrigid -> modifyConstraints $ M.insert dim (0, Size Nothing $ Usage Nothing loc) + Rigid src -> + modifyConstraints $ + M.insert dim (0, UnknowableSize (srclocOf usage) src) + Nonrigid -> + modifyConstraints $ + M.insert dim (0, Size Nothing usage) pure dim curLevel = pure 0 diff --git a/tests/badentry10.fut b/tests/badentry10.fut new file mode 100644 index 0000000000..dd29a3121e --- /dev/null +++ b/tests/badentry10.fut @@ -0,0 +1,7 @@ +-- == +-- input { 1i64 [1,2] } +-- output { 1i64 } +-- compiled input { 1i64 [1,2,3] } +-- error: invalid size + +entry main (x: i64) (_: [x+1]i32) = x diff --git a/tests/badentry9.fut b/tests/badentry9.fut new file mode 100644 index 0000000000..e22ad72042 --- /dev/null +++ b/tests/badentry9.fut @@ -0,0 +1,5 @@ +-- Entry points must use all sizes constructively. +-- == +-- error: \[x\].*constructive + +entry main [x] (_: [x+1]i32) = x diff --git a/tests/binding-warn0.fut b/tests/binding-warn0.fut new file mode 100644 index 0000000000..b00427b60f --- /dev/null +++ b/tests/binding-warn0.fut @@ -0,0 +1,6 @@ +-- It is bad to give an argument with a binding that is used in size +-- but it is accepted +-- == +-- warning: with binding + +def f [n] (ns: *[n]i64) = iota (let m = n+2 in m*m) diff --git a/tests/binding-warn1.fut b/tests/binding-warn1.fut new file mode 100644 index 0000000000..4f9a5e93a5 --- /dev/null +++ b/tests/binding-warn1.fut @@ -0,0 +1,7 @@ +-- Bad to bind in slices, but accepted +-- == +-- warning: with binding + +def main (n:i64) (xs:*[n]i64) = + let t = iota n + in t[:let m = n-4 in m*m/n] diff --git a/tests/entryexpr.fut b/tests/entryexpr.fut new file mode 100644 index 0000000000..053909022d --- /dev/null +++ b/tests/entryexpr.fut @@ -0,0 +1,5 @@ +-- == +-- input { [1,2] [3] } +-- output { [1,2,3] } + +def main [n] (xs: [n*2]i32) (ys: [n]i32) = xs ++ ys diff --git a/tests/flattening/LoopInvReshape.fut b/tests/flattening/LoopInvReshape.fut index 4f365889df..eb019b50d1 100644 --- a/tests/flattening/LoopInvReshape.fut +++ b/tests/flattening/LoopInvReshape.fut @@ -11,6 +11,6 @@ def main [n][m] (xs: [m]i32, ys: [n]i64, zs: [n]i64, is: [n]i32, js: [n]i32): []i32 = map (\(y: i64, z: i64, i: i32, j: i32): i32 -> #[unsafe] - let tmp = unflatten y z xs + let tmp = unflatten y z (xs :> [y*z]i32) in tmp[i,j] ) (zip4 ys zs is js) diff --git a/tests/flattening/map-nested-free.fut b/tests/flattening/map-nested-free.fut new file mode 100644 index 0000000000..0c569bb180 --- /dev/null +++ b/tests/flattening/map-nested-free.fut @@ -0,0 +1,9 @@ +-- == +-- input { [5i64,7i64] [5i64,7i64] } +-- output { [5i64, 7i64] } + +def main = map2 (\n x -> + #[unsafe] + let A = #[opaque] replicate n x + let B = #[opaque] map (\i -> A[i%x]) (iota n) + in B[0]) diff --git a/tests/fusion/Vers2.0/hindrReshape0.fut b/tests/fusion/Vers2.0/hindrReshape0.fut index 093408d5e6..bfc453be70 100644 --- a/tests/fusion/Vers2.0/hindrReshape0.fut +++ b/tests/fusion/Vers2.0/hindrReshape0.fut @@ -6,7 +6,7 @@ -- [1, 2, 3, 4, 5, 6, 7, 8, 9] -- [[2, 4, 6], [8, 10, 12], [14, 16, 18]] -- } -def main (orig: [9]i32): ([]i32,[][]i32) = +def main (orig: [3*3]i32): ([]i32,[][]i32) = let a = map (+1) orig let b = unflatten 3 3 a let c = map (\(row: []i32) -> diff --git a/tests/fusion/fuse-across-reshape-transpose.fut b/tests/fusion/fuse-across-reshape-transpose.fut index d3550278d3..eb7a13f445 100644 --- a/tests/fusion/fuse-across-reshape-transpose.fut +++ b/tests/fusion/fuse-across-reshape-transpose.fut @@ -6,8 +6,7 @@ -- } -- structure { /Screma 1 } def main: [][]i32 = - let n = 9 - let a = map (+1) (map i32.i64 (iota(n))) + let a = map (+1) (map i32.i64 (iota(3*3))) let b = unflatten 3 3 a let c = transpose b in map (\(row: []i32) -> diff --git a/tests/fusion/fuse-across-reshape1.fut b/tests/fusion/fuse-across-reshape1.fut index a2beb24a43..8a3325aa1d 100644 --- a/tests/fusion/fuse-across-reshape1.fut +++ b/tests/fusion/fuse-across-reshape1.fut @@ -8,8 +8,7 @@ -- /Screma 1 -- } def main: [][]i32 = - let n = 9 - let a = map (+1) (map i32.i64 (iota(n))) + let a = map (+1) (map i32.i64 (iota(3*3))) let b = unflatten 3 3 a in map (\(row: []i32) -> map (\(x: i32): i32 -> x*2) row) b diff --git a/tests/fusion/fuse-across-reshape2.fut b/tests/fusion/fuse-across-reshape2.fut index 0223c79bd5..0c40b82870 100644 --- a/tests/fusion/fuse-across-reshape2.fut +++ b/tests/fusion/fuse-across-reshape2.fut @@ -5,9 +5,8 @@ -- [[0, 9, 18], [27, 36, 45], [54, 63, 72]] -- } def main: [][]i32 = - let n = 9 - let a = map (\i -> replicate n (i32.i64 i)) - (iota n) + let a = map (\i -> replicate 9 (i32.i64 i)) + (iota (3*3)) let b = unflatten_3d 3 3 9 (flatten a) in map (\(row: [][]i32) -> map (\(x: []i32): i32 -> reduce (+) 0 x) row) b diff --git a/tests/fusion/fuse-across-reshape3.fut b/tests/fusion/fuse-across-reshape3.fut index fc91edae7d..f1d866a9fa 100644 --- a/tests/fusion/fuse-across-reshape3.fut +++ b/tests/fusion/fuse-across-reshape3.fut @@ -4,4 +4,4 @@ def main(n: i64, m: i64, k: i64): [][][]f32 = map (\(ar: [][]f32): [m][n]f32 -> map (\(arr: []f32): [n]f32 -> scan (+) 0f32 arr) ar) ( - unflatten_3d k m n (map f32.i64 (iota(n*m*k)))) + unflatten_3d k m n (map f32.i64 (iota(k*m*n)))) diff --git a/tests/higher-order-functions/alias0.fut b/tests/higher-order-functions/alias0.fut index f1e7be2be2..393e0d44f3 100644 --- a/tests/higher-order-functions/alias0.fut +++ b/tests/higher-order-functions/alias0.fut @@ -2,4 +2,4 @@ -- generation. def main (w: i64) (h: i64) = - [1,2,3] |> unflatten w h + ([1,2,3] :> [w*h]i32) |> unflatten w h diff --git a/tests/higher-order-functions/shape-params6.fut b/tests/higher-order-functions/shape-params6.fut new file mode 100644 index 0000000000..183220f995 --- /dev/null +++ b/tests/higher-order-functions/shape-params6.fut @@ -0,0 +1,9 @@ +-- == +-- input { [1,2] [3,4,5] } +-- output { 5i64 } + +def f [n][m] (f: [n+m]i32 -> i64) (a: [n]i32) (b: [m]i32) = f (a ++ b) + +def g n m (_: [n+m]i32) = n+m + +def main = f (g 2 3) diff --git a/tests/higher-order-functions/shape-params7.fut b/tests/higher-order-functions/shape-params7.fut new file mode 100644 index 0000000000..361fb1f0b5 --- /dev/null +++ b/tests/higher-order-functions/shape-params7.fut @@ -0,0 +1,11 @@ +-- == +-- input { [1,2,3] } +-- output { [1,2,3] } + +def inc (x: i64) = x + 1 + +def tail [n] 't (A: [inc n]t) = A[1:] :> [n]t + +def cons [n] 't (x: t) (A: [n]t): [inc n]t = [x] ++ A :> [inc n]t + +def main (xs: []i32) = tail (cons 2 xs) diff --git a/tests/issue1112.fut b/tests/issue1112.fut index 28873f4aec..df25f4b2e8 100644 --- a/tests/issue1112.fut +++ b/tests/issue1112.fut @@ -42,7 +42,7 @@ def solveAB [m][n] (A:[m][m]f32) (B:[m][n]f32) : [m][n]f32 = in AB[0:m, m:(m+n)] :> [m][n]f32 def solveAb [m] (A:[m][m]f32) (b:[m]f32) = - unflatten m 1 b |> solveAB A |> flatten_to m + unflatten m 1 (b :> [m*1]f32) |> solveAB A |> flatten_to m def main u_bs (points':[]v3) (forces:[](v3,v3)) = let C (x,y,z) = diff --git a/tests/issue1237.fut b/tests/issue1237.fut index aaf4f65779..3a374e031a 100644 --- a/tests/issue1237.fut +++ b/tests/issue1237.fut @@ -14,24 +14,22 @@ def mkVal [l2][l] (B: i64) (y:i32) (x:i32) (pen:i32) (inp_l:[l2]i32) (ref_l:[l][ , ( (inp_l[fInd B (y-1) x])) - pen ) -def intraBlockPar [lensq][len] (B: i64) - (penalty: i32) - (inputsets: [lensq]i32) - (reference2: [len][len]i32) - (b_y: i64) (b_x: i64) - : [B][B]i32 = +def intraBlockPar [len] (B: i64) + (penalty: i32) + (inputsets: [len*len]i32) + (reference2: [len][len]i32) + (b_y: i64) (b_x: i64) + : [B][B]i32 = let ref_l = reference2[b_y * B + 1: b_y * B + 1 + B, b_x * B + 1: b_x * B + 1 + B] :> [B][B]i32 let inputsets' = unflatten len len inputsets - let B1 = B+1 - - let inp_l' = (copy inputsets'[b_y * B : b_y * B + B + 1, b_x * B : b_x * B + B + 1]) :> *[B1][B1]i32 + let inp_l' = (copy inputsets'[b_y * B : b_y * B + B + 1, b_x * B : b_x * B + B + 1]) :> *[B+1][B+1]i32 -- inp_l is the working memory - let inp_l = replicate (B1*B1) 0i32 - |> unflatten B1 B1 + let inp_l = replicate ((B+1)*(B+1)) 0i32 + |> unflatten (B+1) (B+1) -- Initialize inp_l with the already processed the column to the left of this -- block @@ -85,18 +83,20 @@ def updateBlocks [q][lensq] (B: i64) def main [lensq] (penalty : i32) (inputsets : *[lensq]i32) (reference : *[lensq]i32) : *[lensq]i32 = - let len = i32.f32 (f32.sqrt (f32.i64 lensq)) + let len = i64.f32 (f32.sqrt (f32.i64 lensq)) + let inputsets = inputsets :> [len*len]i32 + let reference = reference :> [len*len]i32 let worksize = len - 1 - let B = i64.min (i64.i32 worksize) B0 + let B = i64.min worksize B0 - let B = assert (i64.i32 worksize % B == 0) B + let B = assert (worksize % B == 0) B - let block_width = trace <| worksize / i32.i64 B - let reference2 = unflatten (i64.i32 len) (i64.i32 len) reference + let block_width = trace <| worksize / B + let reference2 = unflatten len len reference let inputsets = loop inputsets for blk < block_width do - let blk = i64.i32 (blk + 1) + let blk = blk + 1 let block_inp = tabulate blk (\b_x -> let b_y = blk-1-b_x @@ -105,6 +105,6 @@ def main [lensq] (penalty : i32) let mkBY bx = i32.i64 (blk - 1) - bx let mkBX bx = bx - in updateBlocks B len blk mkBY mkBX block_inp inputsets + in updateBlocks B (i32.i64 len) blk mkBY mkBX block_inp inputsets - in inputsets + in inputsets :> [lensq]i32 diff --git a/tests/issue1239.fut b/tests/issue1239.fut index 70fd420c3d..389a79f507 100644 --- a/tests/issue1239.fut +++ b/tests/issue1239.fut @@ -1,7 +1,7 @@ def n = 2i64 def grid (i: i64): [n][n]i64 = - let grid = unflatten n n (0.. unflatten (8**(ii)) 8 |> map restrictCell + let xx = (vx :> [(8**ii)*8][24]f64) |> unflatten (8**(ii)) 8 |> map restrictCell in xx let coarseX = flatten_to 24 coarseValues in coarseX diff --git a/tests/issue1435.fut b/tests/issue1435.fut index f8a22b2e9c..a0e09f9cac 100644 --- a/tests/issue1435.fut +++ b/tests/issue1435.fut @@ -1,5 +1,4 @@ -- == --- error: artificial size def segmented_scan [n] 't (op: t -> t -> t) (ne: t) (flags: [n]bool) (as: [n]t): [n]t = diff --git a/tests/issue245.fut b/tests/issue245.fut index 75f80fe399..c28383b564 100644 --- a/tests/issue245.fut +++ b/tests/issue245.fut @@ -14,7 +14,7 @@ def reshape_int (l: i64) (x: []i32): []i32 = let (v1, _) = split (l) (extend) in v1 entry main (x: i64) (y: i64): [][]i32 = - let t_v1 = unflatten x y (reshape_int ((x * (y * 1))) (map (\x -> + let t_v1 = unflatten x y (reshape_int ((x * y)) (map (\x -> (i32.i64 x + 1)) (iota (6)))) in let t_v2 = transpose (t_v1) in let t_v3 = take_arrint (x) (t_v2) in diff --git a/tests/issue246.fut b/tests/issue246.fut index 7b493f0479..42e8a40f92 100644 --- a/tests/issue246.fut +++ b/tests/issue246.fut @@ -26,7 +26,7 @@ def reshape_int (l: i64) (x: []i64): []i64 = let (v1, _) = split (l) (extend) in v1 entry main (n: i64) (m: i64): []i64 = - let t_v1 = unflatten n m (reshape_int ((n * (m * 1))) ((map (\(x: i64): i64 -> + let t_v1 = unflatten n m (reshape_int ((n * m)) ((map (\(x: i64): i64 -> (x + 1)) (iota (12))))) in let t_v2 = transpose (t_v1) in let t_v3 = take_arrint (2) (t_v2) in diff --git a/tests/issue869.fut b/tests/issue869.fut index 85d762d87a..dec2ab8d4b 100644 --- a/tests/issue869.fut +++ b/tests/issue869.fut @@ -1,7 +1,7 @@ -- The "fix" for this in the internaliser was actually a workaround -- for a type checker bug (#1565). -- == --- error: Loop body does not have expected type +-- error: Initial loop values do not have expected type. def matmult [n][m][p] (x: [n][m]f32) (y: [m][p]f32) : [n][p]f32 = map (\xr -> diff --git a/tests/loops/pow2reduce.fut b/tests/loops/pow2reduce.fut new file mode 100644 index 0000000000..471aec65d0 --- /dev/null +++ b/tests/loops/pow2reduce.fut @@ -0,0 +1,17 @@ +-- Tree reduction that only works on input that is a power of two. +-- == +-- input { [1,2,3,4] } +-- output { 10 } + +def step [k] (xs: [2**k]i32) : [2**(k-1)]i32 = + tabulate (2**(k-1)) (\i -> xs[i*2] + xs[i*2+1]) + +def sum [k] (xs: [2**k]i32) : i32 = + head (loop xs for i in reverse (iota k) do + step (xs :> [2**(i+1)]i32)) + +def ilog2 (n: i64) : i64 = i64.i32 (63 - i64.clz n) + +def main [n] (xs: [n]i32) = + let k = ilog2 n + in sum (xs :> [2**k]i32) diff --git a/tests/memory-block-merging/coalescing/concat/neg0.fut b/tests/memory-block-merging/coalescing/concat/neg0.fut index b8db9acfca..9fee4bf396 100644 --- a/tests/memory-block-merging/coalescing/concat/neg0.fut +++ b/tests/memory-block-merging/coalescing/concat/neg0.fut @@ -13,6 +13,6 @@ let main [n] (ns: [n]i32) (i: i32): [][]i32 = let t0 = map (+ 1) ns -- Will use the memory of t3. let t1 = map (* 2) ns -- Will use the memory of t3. let t2 = map (/ 3) ns -- Will use the memory of t3. - let t3 = unflatten n (n * 3) (replicate (n * n * 3) 0) + let t3 = unflatten n (n * 3) (replicate (n * (n * 3)) 0) let t3[i] = concat_to (n * 3) (concat t0 t1) t2 in t3 diff --git a/tests/memory-block-merging/coalescing/concat/pos1.fut b/tests/memory-block-merging/coalescing/concat/pos1.fut index 817bafee7d..976a2cf789 100644 --- a/tests/memory-block-merging/coalescing/concat/pos1.fut +++ b/tests/memory-block-merging/coalescing/concat/pos1.fut @@ -11,7 +11,7 @@ -- structure gpu-mem { Alloc 1 } let main [n] (ns: [n]i32) (i: i64): [][]i32 = - let t3 = unflatten n (n * 3) (replicate (n * n * 3) 0) + let t3 = unflatten n (n * 3) (replicate (n * (n * 3)) 0) let t0 = map (+ 1) ns -- Will use the memory of t3. let t1 = map (* 2) ns -- Will use the memory of t3. let t2 = map (/ 3) ns -- Will use the memory of t3. diff --git a/tests/memory-block-merging/coalescing/lud/lud.fut b/tests/memory-block-merging/coalescing/lud/lud.fut index 026889eef7..f5b7e93ce3 100644 --- a/tests/memory-block-merging/coalescing/lud/lud.fut +++ b/tests/memory-block-merging/coalescing/lud/lud.fut @@ -32,7 +32,7 @@ def lud_diagonal [b] (a: [b][b]f32): *[b][b]f32 = let mat[i+1] = row in mat - ) (unflatten (opaque 1) b a) + ) (unflatten (opaque 1) b (a :> [opaque 1*b][b]f32)) |> head def lud_perimeter_upper [m][b] (diag: [b][b]f32) (a0s: [m][b][b]f32): *[m][b][b]f32 = diff --git a/tests/modules/ascription-error7.fut b/tests/modules/ascription-error7.fut new file mode 100644 index 0000000000..1ec3fb3c8d --- /dev/null +++ b/tests/modules/ascription-error7.fut @@ -0,0 +1,8 @@ +-- == +-- error: constructive + +module m : { + type sum [n][m] +} = { + type sum [n][m] = [n+m]bool +} diff --git a/tests/modules/ascription15.fut b/tests/modules/ascription15.fut new file mode 100644 index 0000000000..7a390d7c13 --- /dev/null +++ b/tests/modules/ascription15.fut @@ -0,0 +1,7 @@ +module type mt = { + type sum [n][m] = ([n]bool, [m]bool, [n+m]bool) +} + +module m : mt = { + type sum [n][m] = ([n]bool, [m]bool, [n+m]bool) +} diff --git a/tests/modules/sizeparams8.fut b/tests/modules/sizeparams8.fut new file mode 100644 index 0000000000..59d381fa22 --- /dev/null +++ b/tests/modules/sizeparams8.fut @@ -0,0 +1,5 @@ +module meta: { + val plus_comm [a][b]'t : [a+b]t -> [b+a]t +} = { + def plus_comm [a][b]'t (xs: [a+b]t): [b+a]t = xs :> [b+a]t +} diff --git a/tests/modules/sizes2.fut b/tests/modules/sizes2.fut index f4ff59e277..7a1d4e3628 100644 --- a/tests/modules/sizes2.fut +++ b/tests/modules/sizes2.fut @@ -1,5 +1,5 @@ -- == --- error: Dimensions "n" +-- error: Sizes "n" module type withvec_mt = { val n : i64 diff --git a/tests/operator/size-section0.fut b/tests/operator/size-section0.fut new file mode 100644 index 0000000000..2e56061eff --- /dev/null +++ b/tests/operator/size-section0.fut @@ -0,0 +1,4 @@ +-- Check that sizes are well calculated in left section + +def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = + [xs ++ ys] ++ map (xs ++) mat diff --git a/tests/operator/size-section1.fut b/tests/operator/size-section1.fut new file mode 100644 index 0000000000..03640d3220 --- /dev/null +++ b/tests/operator/size-section1.fut @@ -0,0 +1,4 @@ +-- Check that sizes are well calculated in right section + +def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][n]i64) = + [xs ++ ys] ++ map (++ ys) mat diff --git a/tests/operator/size-section2.fut b/tests/operator/size-section2.fut new file mode 100644 index 0000000000..db66c7e76a --- /dev/null +++ b/tests/operator/size-section2.fut @@ -0,0 +1,5 @@ +-- Check that sizes are well calculated in left section, even with bounded existential sizes + +def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = + let xs' = filter (>0) xs + in [xs' ++ ys] ++ map (xs' ++) mat diff --git a/tests/operator/size-section3.fut b/tests/operator/size-section3.fut new file mode 100644 index 0000000000..07dbc624b1 --- /dev/null +++ b/tests/operator/size-section3.fut @@ -0,0 +1,4 @@ +-- Check that sizes are well calculated in left section, with complex sizes +-- == +def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = + [(xs ++ xs) ++ ys] ++ map (xs ++ xs ++) mat diff --git a/tests/operator/size-section4.fut b/tests/operator/size-section4.fut new file mode 100644 index 0000000000..7651d85557 --- /dev/null +++ b/tests/operator/size-section4.fut @@ -0,0 +1,4 @@ +-- Check that sizes are well calculated in right section, with complex sizes +-- == +def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = + [ys ++ (xs ++ xs)] ++ map (++ xs ++ xs) mat diff --git a/tests/replicate3.fut b/tests/replicate3.fut index ea4ac5f437..47d457ac5f 100644 --- a/tests/replicate3.fut +++ b/tests/replicate3.fut @@ -5,5 +5,5 @@ def main [n] (b: [n]i32, m: i64) = let x = n * m let c = b :> [x]i32 - let d = replicate 10 c + let d = replicate (2*5*(n*m)) c in unflatten_3d 2 5 (n*m) d diff --git a/tests/reshape1.fut b/tests/reshape1.fut index e5e18426db..7a6aeaaff5 100644 --- a/tests/reshape1.fut +++ b/tests/reshape1.fut @@ -6,11 +6,11 @@ -- [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64], [7i64, 8i64, 9i64]] -- } -- input { [1i64,2i64,3i64] } --- error: Cannot unflatten array of shape \[3\] to array of shape \[1\]\[1\] +-- error: (3)*cannot match shape.*\[1\]i64 def intsqrt(x: i64): i64 = i64.f32(f32.sqrt(f32.i64(x))) def main [n] (a: [n]i64): [][]i64 = - unflatten (intsqrt n) (intsqrt n) a + unflatten (intsqrt n) (intsqrt n) (a :> [intsqrt n*intsqrt n]i64) diff --git a/tests/rosettacode/md5.fut b/tests/rosettacode/md5.fut index 0b0890d172..bc2afd86a1 100644 --- a/tests/rosettacode/md5.fut +++ b/tests/rosettacode/md5.fut @@ -53,7 +53,7 @@ def unbytes(bs: [4]u8): u32 = u32.u8(bs[2]) * 0x10000u32 + u32.u8(bs[3]) * 0x1000000u32 -def unbytes_block(block: [64]u8): [16]u32 = +def unbytes_block(block: [16*4]u8): [16]u32 = map unbytes (unflatten 16 4 block) -- Process 512 bits of the input. @@ -82,10 +82,12 @@ def md5 [n] (ms: [n][16]u32): md5 = def main [n] (ms: [n]u8): [16]u8 = let padding = 64 - (n % 64) let n_padded = n + padding + let num_blocks = n_padded / 64 let ms_padded = ms ++ bytes 0x80u32 ++ replicate (padding-12) 0x0u8 ++ bytes (u32.i64(n*8)) ++ [0u8,0u8,0u8,0u8] - let (a,b,c,d) = md5 (map unbytes_block (unflatten (n_padded / 64) 64 ms_padded)) + :> [num_blocks*(16*4)]u8 + let (a,b,c,d) = md5 (map unbytes_block (unflatten num_blocks (16*4) ms_padded)) in flatten (map bytes [a,b,c,d]) :> [16]u8 diff --git a/tests/scatter/nw.fut b/tests/scatter/nw.fut index 9cb3ee3838..3bf4148080 100644 --- a/tests/scatter/nw.fut +++ b/tests/scatter/nw.fut @@ -15,12 +15,12 @@ def mkVal [l2][l] (y:i32) (x:i32) (pen:i32) (inp_l:[l2][l2]i32) (ref_l:[l][l]i32 , ( (inp_l[y - 1, x])) - pen ) -def intraBlockPar [lensq][len] (B: i64) - (penalty: i32) - (inputsets: [lensq]i32) - (reference2: [len][len]i32) - (b_y: i64) (b_x: i64) - : [B][B]i32 = +def intraBlockPar [len] (B: i64) + (penalty: i32) + (inputsets: [len*len]i32) + (reference2: [len][len]i32) + (b_y: i64) (b_x: i64) + : [B][B]i32 = let ref_l = reference2[b_y * B + 1: b_y * B + 1 + B, b_x * B + 1: b_x * B + 1 + B] :> [B][B]i32 @@ -47,12 +47,12 @@ def intraBlockPar [lensq][len] (B: i64) in inp_l[1:B+1,1:B+1] :> [B][B]i32 -def updateBlocks [q][lensq] (B: i64) - (len: i32) (blk: i64) - (mk_b_y: (i32 -> i32)) - (mk_b_x: (i32 -> i32)) - (block_inp: [q][B][B]i32) - (inputsets: *[lensq]i32) = +def updateBlocks [q] (B: i64) + (len: i64) (blk: i64) + (mk_b_y: (i32 -> i32)) + (mk_b_x: (i32 -> i32)) + (block_inp: [q][B][B]i32) + (inputsets: *[len*len]i32) = let (inds, vals) = unzip ( tabulate (blk*B*B) (\gid -> let B2 = i32.i64 (B*B) @@ -63,7 +63,7 @@ def updateBlocks [q][lensq] (B: i64) let b_y = mk_b_y bx let b_x = mk_b_x bx let v = #[unsafe] block_inp[bx, ty, tx] - let ind = (i32.i64 B*b_y + 1 + ty) * len + (i32.i64 B*b_x + tx + 1) + let ind = (i32.i64 B*b_y + 1 + ty) * i32.i64 len + (i32.i64 B*b_x + tx + 1) in (i64.i32 ind, v))) in scatter inputsets inds vals @@ -71,20 +71,21 @@ def updateBlocks [q][lensq] (B: i64) def main [lensq] (penalty : i32) (inputsets : *[lensq]i32) (reference : *[lensq]i32) : *[lensq]i32 = - let len = i32.f32 (f32.sqrt (f32.i64 lensq)) + let len = i64.f32 (f32.sqrt (f32.i64 lensq)) let worksize = len - 1 - let B = i64.min (i64.i32 worksize) B0 + let B = i64.min worksize B0 -- worksize should be a multiple of B0 - let B = assert (i64.i32 worksize % B == 0) B + let B = assert (worksize % B == 0) B - let block_width = trace <| worksize / i32.i64 B - let reference2 = unflatten (i64.i32 len) (i64.i32 len) reference + let block_width = trace <| worksize / B + let reference2 = unflatten len len (reference :> [len*len]i32) + let inputsets = (inputsets :> [len*len]i32) -- First anti-diagonal half of the entire input matrix let inputsets = loop inputsets for blk < block_width do - let blk = i64.i32 (blk + 1) + let blk = blk + 1 let block_inp = -- Process an anti-diagonal of independent blocks tabulate blk (\b_x -> @@ -96,4 +97,4 @@ def main [lensq] (penalty : i32) let mkBX bx = bx in updateBlocks B len blk mkBY mkBX block_inp inputsets - in inputsets + in inputsets :> [lensq]i32 diff --git a/tests/shapes/error4.fut b/tests/shapes/error4.fut index 4454afb4a4..b842bdf44a 100644 --- a/tests/shapes/error4.fut +++ b/tests/shapes/error4.fut @@ -1,6 +1,6 @@ -- We cannot just ignore constraints imposed by a higher-order function. -- == --- error: Dimensions.*"n".*do not match +-- error: Sizes.*"n".*do not match def f (g: (n: i64) -> [n]i32) (l: i64): i32 = (g l)[0] diff --git a/tests/shapes/error5.fut b/tests/shapes/error5.fut index bbe7d92516..00aa455b67 100644 --- a/tests/shapes/error5.fut +++ b/tests/shapes/error5.fut @@ -1,6 +1,6 @@ -- A function 'a -> a' must be size-preserving. -- == --- error: do not match +-- error: Occurs check def ap 'a (f: a -> a) (x: a) = f x diff --git a/tests/shapes/existential-argument.fut b/tests/shapes/existential-argument.fut new file mode 100644 index 0000000000..3895523f34 --- /dev/null +++ b/tests/shapes/existential-argument.fut @@ -0,0 +1,9 @@ +-- Sizes obtained with existantialy bounded sizes should not be calculated +-- == +-- input { 2i64 } output { [0i64, 1i64, 0i64, 1i64] } + +def double_eval 't (f : () -> []t) : []t = + f () ++ f () + +def main (n:i64) : []i64 = + double_eval (\_ -> iota n) diff --git a/tests/shapes/field-in-size.fut b/tests/shapes/field-in-size.fut new file mode 100644 index 0000000000..7e17a95a9b --- /dev/null +++ b/tests/shapes/field-in-size.fut @@ -0,0 +1,4 @@ +-- Allow to access argument field as size for return type +-- == + +def f (p: {a:i64,b:bool}) : [p.a]i64 = iota p.a diff --git a/tests/shapes/funshape3.fut b/tests/shapes/funshape3.fut index 13527eb151..0814b27ae6 100644 --- a/tests/shapes/funshape3.fut +++ b/tests/shapes/funshape3.fut @@ -1,5 +1,5 @@ -- == --- error: Causality check +-- input { 5i64 } output { 7i64 } def f [n] (_: [n]i64) (_: [n]i64 -> i32, _: [n]i64) = n diff --git a/tests/shapes/funshape5.fut b/tests/shapes/funshape5.fut index 6c36167b3a..c1c5b9ede9 100644 --- a/tests/shapes/funshape5.fut +++ b/tests/shapes/funshape5.fut @@ -1,5 +1,5 @@ -- == --- error: Causality check +-- error: Entry point functions may not be polymorphic def main indices (cs: *[](i32,i32)) j = map (\k -> (indices[j],k)) <| drop (j+1) indices diff --git a/tests/shapes/funshape6.fut b/tests/shapes/funshape6.fut index c18b8db9f0..c2570c2e20 100644 --- a/tests/shapes/funshape6.fut +++ b/tests/shapes/funshape6.fut @@ -1,5 +1,5 @@ -- Based on issue 1351. -- == --- error: Causality +-- input { [[1.0,2.0,3.0],[4.0,5.0,6.0]] 0i64 4i64 } def main (xs: [][]f64) i j = (.[i:j]) <| iota (i+j) diff --git a/tests/shapes/implicit-shape-use.fut b/tests/shapes/implicit-shape-use.fut index 134e4fe16d..5dda406146 100644 --- a/tests/shapes/implicit-shape-use.fut +++ b/tests/shapes/implicit-shape-use.fut @@ -52,11 +52,10 @@ def brownianBridgeDates [num_dates] in bbrow def brownianBridge [num_dates] - (num_und: i64, - bb_inds: [3][num_dates]i32, - bb_data: [3][num_dates]f64, - gaussian_arr: []f64 - ) = + (num_und: i64) + (bb_inds: [3][num_dates]i32) + (bb_data: [3][num_dates]f64) + (gaussian_arr: []f64) = let gauss2d = unflatten num_dates num_und gaussian_arr let gauss2dT = transpose gauss2d in transpose ( @@ -65,11 +64,9 @@ def brownianBridge [num_dates] def main [num_dates] (num_und: i64) (bb_inds: [3][num_dates]i32) - (arr_usz: []f64): [][]f64 = - let n = num_dates*num_und - let arr = arr_usz :> [n]f64 + (arr: [num_dates*num_und]f64): [][]f64 = let bb_data= map (\(row: []i32) -> map f64.i32 row ) (bb_inds ) - let bb_mat = brownianBridge( num_und, bb_inds, bb_data, arr ) + let bb_mat = brownianBridge num_und bb_inds bb_data arr in bb_mat diff --git a/tests/shapes/modules1.fut b/tests/shapes/modules1.fut index babb751b2c..1e8ddc345b 100644 --- a/tests/shapes/modules1.fut +++ b/tests/shapes/modules1.fut @@ -1,7 +1,7 @@ -- It is not allowed to create an opaque type whose size parameters -- are not used in array dimensions. -- == --- error: "n" +-- error: is not used constructively module m = { type^ t [n] = [n]i32 -> i64 diff --git a/tests/shapes/opaque0.fut b/tests/shapes/opaque0.fut new file mode 100644 index 0000000000..4c2baa7b18 --- /dev/null +++ b/tests/shapes/opaque0.fut @@ -0,0 +1,21 @@ +-- == +-- error: do not match + +module num: { + type t[n] + val mk : (x: i64) -> t[x] + val un [n] : t[n] -> i64 + val comb [n] : t[n] -> t[n] -> i64 +} = { + type t[n] = [n]() + def mk x = replicate x () + def un x = length x + def comb x y = length (zip x y) +} + +def f x = + let y = x + 1 + in num.mk y + +def main a b = + num.comb (f a) (f b) diff --git a/tests/shapes/range0.fut b/tests/shapes/range0.fut index c6596d35e1..4ea5ec15eb 100644 --- a/tests/shapes/range0.fut +++ b/tests/shapes/range0.fut @@ -1,4 +1,4 @@ -- Some ranges have known sizes. -def main (n: i64) : ([n]i64, [n]i64) = - (0.. unflatten 9 3 |> unflatten 3 3 + let vs = iota (3*3*3) |> unflatten (3*3) 3 |> unflatten 3 3 let zs = flat_update_3d xs 17 27 10 1 vs in zs diff --git a/tests/slice-lmads/lud.fut b/tests/slice-lmads/lud.fut index 4a82a638ce..a014f07f9e 100644 --- a/tests/slice-lmads/lud.fut +++ b/tests/slice-lmads/lud.fut @@ -31,7 +31,7 @@ def lud_diagonal [b] (a: [b][b]f32): *[b][b]f32 = let mat[i+1] = row in mat - ) (unflatten (opaque 1) b a) + ) (unflatten (opaque 1) b (a :> [opaque 1*b][b]f32)) |> head def lud_perimeter_upper [m][b] (diag: [b][b]f32) (a0s: [m][b][b]f32): *[m][b][b]f32 = diff --git a/tests/slice-lmads/small.fut b/tests/slice-lmads/small.fut index 6dd00da496..5340eb9c5f 100644 --- a/tests/slice-lmads/small.fut +++ b/tests/slice-lmads/small.fut @@ -29,5 +29,5 @@ entry index_antidiag [n] (xs: [n]i64): [][][]i64 = flat_index_3d xs 2 3 3 2 4 2 1 entry update_antidiag [n] (xs: *[n]i64): *[n]i64 = - let vs = iota (2 * 2 * 2) |> unflatten 4 2 |> unflatten 2 2 + let vs = iota (2*2*2) |> unflatten (2*2) 2 |> unflatten 2 2 in flat_update_3d xs 1 9 4 1 vs diff --git a/tests/slice-lmads/small_4d.fut b/tests/slice-lmads/small_4d.fut index b65e7976f5..01dd3d8e0c 100644 --- a/tests/slice-lmads/small_4d.fut +++ b/tests/slice-lmads/small_4d.fut @@ -29,5 +29,5 @@ entry index_antidiag [n] (xs: [n]i64): [][][][]i64 = flat_index_4d xs 2 2 8 2 2 2 4 2 1 entry update_antidiag [n] (xs: *[n]i64): *[n]i64 = - let vs = iota (2 * 2 * 2 * 2) |> unflatten 8 2 |> unflatten 4 2 |> unflatten 2 2 + let vs = iota (2*2*2*2) |> unflatten (2*2*2) 2 |> unflatten (2*2) 2 |> unflatten 2 2 in flat_update_4d xs 2 8 2 4 1 vs diff --git a/tests/sumtypes/existential-match.fut b/tests/sumtypes/existential-match.fut new file mode 100644 index 0000000000..4e9a2cc68b --- /dev/null +++ b/tests/sumtypes/existential-match.fut @@ -0,0 +1,8 @@ +type~ sumT = #foo []i64 | #bar i64 + +def thing xs : sumT = #foo (filter (>0) xs) + +def main (xs: []i64) : []i64 = + match thing xs + case #foo xs' -> xs ++ xs' + case #bar i -> xs ++ [i] diff --git a/tests/types/badsquare-lam.fut b/tests/types/badsquare-lam.fut index 93c6c9efbe..a199274e7d 100644 --- a/tests/types/badsquare-lam.fut +++ b/tests/types/badsquare-lam.fut @@ -1,5 +1,6 @@ +-- The error here could be better. -- == --- error: Dimensions.*do not match +-- error: scope violation type square [n] 't = [n][n]t diff --git a/tests/types/badsquare.fut b/tests/types/badsquare.fut index f24e62ed01..580a23a03d 100644 --- a/tests/types/badsquare.fut +++ b/tests/types/badsquare.fut @@ -1,5 +1,5 @@ -- == --- error: Dimensions.*do not match +-- error: Sizes.*do not match type square [n] 't = [n][n]t diff --git a/tests/types/ext2.fut b/tests/types/ext2.fut index 766084e5e8..6c3c118ce1 100644 --- a/tests/types/ext2.fut +++ b/tests/types/ext2.fut @@ -1,5 +1,5 @@ -- == --- error: Dimensions .* do not match +-- error: Sizes .* do not match type^ t = ?[n].([n]bool, bool -> [n]bool) diff --git a/tests/types/metasizes.fut b/tests/types/metasizes.fut new file mode 100644 index 0000000000..bdda4dd9aa --- /dev/null +++ b/tests/types/metasizes.fut @@ -0,0 +1,52 @@ +-- A tricky test of type-level programming. +-- == +-- input { [1,2,3] [4,5,6] [7,8,9] } +-- output { [1, 2, 3, 4, 5, 6, 7, 8, 9] +-- [4, 5, 6, 1, 2, 3, 7, 8, 9] +-- } + +module meta: { + type eq[n][m] + + val coerce [n][m]'t : eq[n][m] -> [n]t -> [m]t + val coerce_inner [n][m]'t [k] : eq[n][m] -> [k][n]t -> [k][m]t + + val refl [n] : eq[n][n] + val comm [n][m] : eq[n][m] -> eq[m][n] + val trans [n][m][k] : eq[n][m] -> eq[m][k] -> eq[n][k] + + val plus_comm [a][b] : eq[a+b][b+a] + val plus_assoc [a][b][c] : eq[(a+b)+c][a+(b+c)] + val plus_lhs [a][b][c] : eq[a][b] -> eq[a+c][b+c] + val plus_rhs [a][b][c] : eq[c][b] -> eq[a+c][a+b] + + val mult_comm [a][b] : eq[a*b][b*a] + val mult_assoc [a][b][c] : eq[(a*b)*c][a*(b*c)] + val mult_lhs [a][b][c] : eq[a][b] -> eq[a+c][b+c] + val mult_rhs [a][b][c] : eq[c][b] -> eq[a+c][a+b] +} = { + type eq[n][m] = [0][n][m]() + + def coerce [n][m]'t (_: eq[n][m]) (a: [n]t) = a :> [m]t + def coerce_inner [n][m]'t [k] (_: eq[n][m]) (a: [k][n]t) = a :> [k][m]t + + def refl = [] + def comm _ = [] + def trans _ _ = [] + + def plus_comm = [] + def plus_assoc = [] + def plus_lhs _ = [] + def plus_rhs _ = [] + + def mult_comm = [] + def mult_assoc = [] + def mult_lhs _ = [] + def mult_rhs _ = [] +} + +def main [n][m][l] (xs: [n]i32) (ys: [m]i32) (zs: [l]i32) = + let proof : meta.eq[m+(n+l)][(n+m)+l] = + meta.comm meta.plus_assoc `meta.trans` meta.plus_lhs meta.plus_comm + in zip ((xs ++ ys) ++ zs) (meta.coerce proof (ys ++ (xs ++ zs))) + |> unzip diff --git a/tests/types/sizeparams10.fut b/tests/types/sizeparams10.fut new file mode 100644 index 0000000000..9659e6e50c --- /dev/null +++ b/tests/types/sizeparams10.fut @@ -0,0 +1,11 @@ +-- What about size parameters that are only known in complex expressions? +-- == +-- input { [1,2] [3,4] } +-- output { [3,4,1,2] } + +type eq[n][m] = [n][m]() +def coerce [n][m]'t (_: eq[n][m]) (a: [n]t) = a :> [m]t +def plus_comm [a][b]'t : eq[a+b] [b+a] = tabulate_2d (a+b) (b+a) (\_ _ -> ()) + +def main [n][m] (xs: [n]i32) (ys: [m]i32) = + copy (coerce plus_comm (ys ++ xs)) diff --git a/tests/types/sizeparams11.fut b/tests/types/sizeparams11.fut new file mode 100644 index 0000000000..363186f32b --- /dev/null +++ b/tests/types/sizeparams11.fut @@ -0,0 +1,10 @@ +-- Another complicated case. +-- == +-- input { 1i64 2i64 } +-- output { [[true, true, true], [true, true, true], [true, true, true]] } + +def plus a b : i64 = a + b + +def plus_comm [a][b]'t : [plus a b][plus b a]bool = tabulate_2d (plus a b) (plus b a) (\_ _ -> true) + +def main a b = copy plus_comm : [plus a b][plus b a]bool diff --git a/tests/uniqueness/uniqueness-error62.fut b/tests/uniqueness/uniqueness-error62.fut new file mode 100644 index 0000000000..ecee93beb8 --- /dev/null +++ b/tests/uniqueness/uniqueness-error62.fut @@ -0,0 +1,5 @@ +-- Size expression should be non-consuming +-- == +-- error: "ns".*not consumable +def consume (xs: *[]i64): i64 = xs[0] +def f [n] (ns: *[n]i64) (xs: [consume ns]f32) = xs[0] diff --git a/tests/uniqueness/uniqueness-warn0.fut b/tests/uniqueness/uniqueness-warn0.fut new file mode 100644 index 0000000000..1f5a0fb1b0 --- /dev/null +++ b/tests/uniqueness/uniqueness-warn0.fut @@ -0,0 +1,7 @@ +-- It is bad to give consuming argument that is used in size +-- but it is accepted +-- == +-- warning: with consumption + +def consume (xs: *[]i64): i64 = xs[0] +def f [n] (ns: *[n]i64) = iota (consume ns) diff --git a/tests/uniqueness/uniqueness-warn1.fut b/tests/uniqueness/uniqueness-warn1.fut new file mode 100644 index 0000000000..e70fd69bfe --- /dev/null +++ b/tests/uniqueness/uniqueness-warn1.fut @@ -0,0 +1,9 @@ +-- Bad to consume in slices, but accepted +-- == +-- warning: with consumption + +def consume (xs: *[]i64): i64 = xs[0] + +def main (n:i64) (xs:*[n]i64) = + let t = iota n + in t[:consume xs] diff --git a/tests/uniqueness/uniqueness59.fut b/tests/uniqueness/uniqueness59.fut new file mode 100644 index 0000000000..9441fa4ccc --- /dev/null +++ b/tests/uniqueness/uniqueness59.fut @@ -0,0 +1,4 @@ +-- It is ok to consuming non-free variables +-- == +def consume (xs: *[]i64): i64 = xs[0] +def f [n] (ns: [n]i64) (xs: [consume (iota 10)]f32) = xs[0] diff --git a/tests/uniqueness/uniqueness60.fut b/tests/uniqueness/uniqueness60.fut new file mode 100644 index 0000000000..5af39254bd --- /dev/null +++ b/tests/uniqueness/uniqueness60.fut @@ -0,0 +1,5 @@ +-- If consumption is on bounded var, no problem +-- == +-- warning: ^$ +def consume (xs: *[]i64): i64 = xs[0] +def f (n:i64) = iota (consume (iota n)) diff --git a/tests_repl/issue1347.fut b/tests_repl/issue1347.fut index 70bea702b3..5d215999c9 100644 --- a/tests_repl/issue1347.fut +++ b/tests_repl/issue1347.fut @@ -1,6 +1,6 @@ entry blockify [n] (b: i64) (xs: [n][n]i32) = - xs + (xs :> [(n/b)*b][(n/b)*b]i32) |> unflatten (n / b) b |> map transpose |> map (unflatten (n / b) b) diff --git a/tests_repl/issue1347.in b/tests_repl/issue1347.in index 95e09895a0..7454d072f6 100644 --- a/tests_repl/issue1347.in +++ b/tests_repl/issue1347.in @@ -1,2 +1,2 @@ -let blocked = blockify 2 (copy <| unflatten 8 8 (map i32.i64 <| iota 64)) +let blocked = blockify 2 (copy <| unflatten 8 8 (map i32.i64 <| iota (8*8))) let pre_transpose = map (map transpose) blocked diff --git a/unittests/Language/Futhark/SyntaxTests.hs b/unittests/Language/Futhark/SyntaxTests.hs index 6440bdb61c..c4e2829525 100644 --- a/unittests/Language/Futhark/SyntaxTests.hs +++ b/unittests/Language/Futhark/SyntaxTests.hs @@ -115,8 +115,8 @@ pSize :: Parser Size pSize = brackets $ choice - [ ConstSize <$> lexeme L.decimal, - NamedSize <$> pQualName + [ flip sizeFromInteger mempty <$> lexeme L.decimal, + flip sizeFromName mempty <$> pQualName ] pScalarNonFun :: Parser (ScalarTypeBase Size ()) @@ -132,8 +132,8 @@ pScalarNonFun = pTypeVar = TypeVar () <$> pUniqueness <*> pQualName <*> many pTypeArg pTypeArg = choice - [ TypeArgDim <$> pSize <*> pure mempty, - TypeArgType <$> pTypeArgType <*> pure mempty + [ TypeArgDim <$> pSize, + TypeArgType <$> pTypeArgType ] pTypeArgType = choice diff --git a/unittests/Language/Futhark/TypeChecker/TypesTests.hs b/unittests/Language/Futhark/TypeChecker/TypesTests.hs index 7f818c7576..283eb921a9 100644 --- a/unittests/Language/Futhark/TypeChecker/TypesTests.hs +++ b/unittests/Language/Futhark/TypeChecker/TypesTests.hs @@ -11,6 +11,7 @@ import Language.Futhark.Semantic import Language.Futhark.SyntaxTests () import Language.Futhark.TypeChecker (initialEnv) import Language.Futhark.TypeChecker.Monad +import Language.Futhark.TypeChecker.Terms import Language.Futhark.TypeChecker.Types import Test.Tasty import Test.Tasty.HUnit @@ -31,7 +32,7 @@ evalTest te expected = assertFailure $ "Expected error, got: " <> show actual_t where extract (_, svars, t, _) = (svars, t) - run = snd . runTypeM env mempty (mkInitialImport "") blankNameSource + run = snd . runTypeM env mempty (mkInitialImport "") blankNameSource checkSizeExp -- We hack up an environment with some predefined type -- abbreviations for testing. This is all prettyString sensitive to the -- specific unique names, so we have to be careful!