Skip to content

Commit

Permalink
Make consumption an effect on functions, rather than types.
Browse files Browse the repository at this point in the history
This is a breaking change, because until now we allowed functions like

    def f (a: *[]i32, b: []i32) = ...

where we could then pass in a tuple where in an application `f (x,y)`
the value `x` would be consumed, but not `y`.  However, this became
increasingly difficult to support as the language grew (and frankly,
it was always buggy).  With this commit, the syntax above is still
permitted, but it is interpreted as

    def f ((a,b): *([]i32, []i32)) = ...

i.e. the single tuple argument is consumed *as a whole*.  Long term we
can also consider amending the syntax or warning about cases where it
is misleading, but that is less urgent.

I've wanted to make this simplification for a long time, but I always
hit various snags.  Today I managed to make it work, and the next step
will be cleaning up the notion of "uniqueness" in return types as well
(it should be the more general notion of "aliases").
  • Loading branch information
athas committed Feb 12, 2023
1 parent 6cdd1b2 commit 8e11ee9
Show file tree
Hide file tree
Showing 32 changed files with 610 additions and 529 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Changed

* If part of a function parameter is marked as consuming ("unique"),
the *entire* parameter is now marked as consuming.

### Fixed

* A somewhat obscure simplification rule could mess up use of memory.
Expand Down
35 changes: 35 additions & 0 deletions docs/error-index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,41 @@ inserting copies to break the aliasing:
def main (xs: *[]i32) : (*[]i32, *[]i32) = (xs, copy xs)
.. _self-aliasing-arg:

"Argument passed for consuming parameter is self-aliased."
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Caused by programs like the following:

.. code-block:: futhark
def g (t: *([]i64, []i64)) = 0
def f n =
let x = iota n
in g (x,x)
The function ``g`` expects to consume two separate ``[]i64`` arrays,
but ``f`` passes it a tuple containing two references to the same
physical array. This is not allowed, as ``g`` must be allowed to
assume that components of consuming record- or tuple parameters have
no internal aliases. We can fix this by inserting copies to break the
aliasing:

.. code-block:: futhark
def f n =
let x = iota n
in g (copy (x,x))
Alternative, we could duplicate the expression producing the array:

.. code-block:: futhark
def f n =
g (iota n, iota n))
.. _consuming-parameter:

"Consuming parameter passed non-unique argument"
Expand Down
28 changes: 20 additions & 8 deletions docs/language-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1434,15 +1434,27 @@ prefixing it with an asterisk. For a return type, we can mark it as
def modify (a: *[]i32) (i: i32) (x: i32): *[]i32 =
a with [i] = a[i] + x

A parameter that is not consuming is called *observing*. In the
parameter declaration ``a: *[i32]``, the asterisk means that the
function ``modify`` has been given "ownership" of the array ``a``,
meaning that any caller of ``modify`` will never reference array ``a``
after the call again. This allows the ``with`` expression to perform
an in-place update. After a call ``modify a i x``, neither ``a`` or
any variable that *aliases* ``a`` may be used on any following
execution path.

If an asterisk is present at *any point* inside a tuple parameter
type, the parameter as a whole is considered consuming. For example::

def consumes_both ((a,b): (*[]i32,[]i32)) = ...

This is usually not desirable behaviour. Use multiple parameters
instead::

def consumes_first_arg (a: *[]i32) (b: []i32) = ...

For bulk in-place updates with multiple values, use the ``scatter``
function in the basis library. In the parameter declaration ``a:
*[i32]``, the asterisk means that the function ``modify`` has been
given "ownership" of the array ``a``, meaning that any caller of
``modify`` will never reference array ``a`` after the call again.
This allows the ``with`` expression to perform an in-place update.

After a call ``modify a i x``, neither ``a`` or any variable that
*aliases* ``a`` may be used on any following execution path.
function in the basis library.

Alias Analysis
~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions prelude/ad.fut
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
-- | Jacobian-Vector Product ("forward mode"), producing also the
-- primal result as the first element of the result tuple.
let jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) =
intrinsics.jvp2 (f, x, x')
intrinsics.jvp2 f x x'

-- | Vector-Jacobian Product ("reverse mode"), producing also the
-- primal result as the first element of the result tuple.
let vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) =
intrinsics.vjp2 (f, x, y')
intrinsics.vjp2 f x y'

-- | Jacobian-Vector Product ("forward mode").
let jvp 'a 'b (f: a -> b) (x: a) (x': a) : b =
Expand Down
6 changes: 3 additions & 3 deletions prelude/array.fut
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def reverse [n] 't (x: [n]t): [n]t = x[::-1]
-- **Work:** O(n).
--
-- **Span:** O(1).
def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = intrinsics.concat (xs, ys)
def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = intrinsics.concat xs ys

-- | An old-fashioned way of saying `++`.
def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[]t = xs ++ ys
Expand All @@ -83,7 +83,7 @@ def concat_to [n] [m] 't (k: i64) (xs: [n]t) (ys: [m]t): *[k]t = xs ++ ys :> [k]
--
-- Note: In most cases, `rotate` will be fused with subsequent
-- operations such as `map`, in which case it is free.
def rotate [n] 't (r: i64) (xs: [n]t): [n]t = intrinsics.rotate (r, xs)
def rotate [n] 't (r: i64) (xs: [n]t): [n]t = intrinsics.rotate r xs

-- | Construct an array of consecutive integers of the given length,
-- starting at 0.
Expand Down Expand Up @@ -143,7 +143,7 @@ def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): []t =
--
-- **Complexity:** O(1).
def unflatten [p] 't (n: i64) (m: i64) (xs: [p]t): [n][m]t =
intrinsics.unflatten (n, m, xs) :> [n][m]t
intrinsics.unflatten n m xs :> [n][m]t

-- | Like `unflatten`, but produces three dimensions.
def unflatten_3d [p] 't (n: i64) (m: i64) (l: i64) (xs: [p]t): [n][m][l]t =
Expand Down
30 changes: 15 additions & 15 deletions prelude/soacs.fut
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import "zip"
--
-- **Span:** *O(S(f))*
def map 'a [n] 'x (f: a -> x) (as: [n]a): *[n]x =
intrinsics.map (f, as)
intrinsics.map f as

-- | Apply the given function to each element of a single array.
--
Expand Down Expand Up @@ -104,7 +104,7 @@ def map5 'a 'b 'c 'd 'e [n] 'x (f: a -> b -> c -> d -> e -> x) (as: [n]a) (bs: [
-- Note that the complexity implies that parallelism in the combining
-- operator will *not* be exploited.
def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a =
intrinsics.reduce (op, ne, as)
intrinsics.reduce op ne as

-- | As `reduce`, but the operator must also be commutative. This is
-- potentially faster than `reduce`. For simple built-in operators,
Expand All @@ -115,7 +115,7 @@ def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a =
--
-- **Span:** *O(log(n) ✕ W(op))*
def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a =
intrinsics.reduce_comm (op, ne, as)
intrinsics.reduce_comm op ne as

-- | `h = hist op ne k is as` computes a generalised `k`-bin histogram
-- `h`, such that `h[i]` is the sum of those values `as[j]` for which
Expand All @@ -130,7 +130,7 @@ def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): a =
--
-- In practice, linear span only occurs if *k* is also very large.
def hist 'a [n] (op: a -> a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k]a =
intrinsics.hist_1d (1, map (\_ -> ne) (0..1..<k), op, ne, is, as)
intrinsics.hist_1d 1 (map (\_ -> ne) (0..1..<k)) op ne is as

-- | Like `hist`, but with initial contents of the histogram, and the
-- complexity is proportional only to the number of input elements,
Expand All @@ -143,15 +143,15 @@ def hist 'a [n] (op: a -> a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k
--
-- In practice, linear span only occurs if *k* is also very large.
def reduce_by_index 'a [k] [n] (dest : *[k]a) (f : a -> a -> a) (ne : a) (is : [n]i64) (as : [n]a) : *[k]a =
intrinsics.hist_1d (1, dest, f, ne, is, as)
intrinsics.hist_1d 1 dest f ne is as

-- | As `reduce_by_index`, but with two-dimensional indexes.
def reduce_by_index_2d 'a [k] [n] [m] (dest : *[k][m]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64)) (as : [n]a) : *[k][m]a =
intrinsics.hist_2d (1, dest, f, ne, is, as)
intrinsics.hist_2d 1 dest f ne is as

-- | As `reduce_by_index`, but with three-dimensional indexes.
def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a) (ne : a) (is : [n](i64,i64,i64)) (as : [n]a) : *[k][m][l]a =
intrinsics.hist_3d (1, dest, f, ne, is, as)
intrinsics.hist_3d 1 dest f ne is as

-- | Inclusive prefix scan. Has the same caveats with respect to
-- associativity and complexity as `reduce`.
Expand All @@ -160,7 +160,7 @@ def reduce_by_index_3d 'a [k] [n] [m] [l] (dest : *[k][m][l]a) (f : a -> a -> a)
--
-- **Span:** *O(log(n) ✕ W(op))*
def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a =
intrinsics.scan (op, ne, as)
intrinsics.scan op ne as

-- | Remove all those elements of `as` that do not satisfy the
-- predicate `p`.
Expand All @@ -169,7 +169,7 @@ def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a): *[n]a =
--
-- **Span:** *O(log(n) ✕ W(p))*
def filter [n] 'a (p: a -> bool) (as: [n]a): *[]a =
let (as', is) = intrinsics.partition (1, \x -> if p x then 0 else 1, as)
let (as', is) = intrinsics.partition 1 (\x -> if p x then 0 else 1) as
in as'[:is[0]]

-- | Split an array into those elements that satisfy the given
Expand All @@ -180,7 +180,7 @@ def filter [n] 'a (p: a -> bool) (as: [n]a): *[]a =
-- **Span:** *O(log(n) ✕ W(p))*
def partition [n] 'a (p: a -> bool) (as: [n]a): ([]a, []a) =
let p' x = if p x then 0 else 1
let (as', is) = intrinsics.partition (2, p', as)
let (as', is) = intrinsics.partition 2 p' as
in (as'[0:is[0]], as'[is[0]:n])

-- | Split an array by two predicates, producing three arrays.
Expand All @@ -190,7 +190,7 @@ def partition [n] 'a (p: a -> bool) (as: [n]a): ([]a, []a) =
-- **Span:** *O(log(n) ✕ (W(p1) + W(p2)))*
def partition2 [n] 'a (p1: a -> bool) (p2: a -> bool) (as: [n]a): ([]a, []a, []a) =
let p' x = if p1 x then 0 else if p2 x then 1 else 2
let (as', is) = intrinsics.partition (3, p', as)
let (as', is) = intrinsics.partition 3 p' as
in (as'[0:is[0]], as'[is[0]:is[0]+is[1]], as'[is[0]+is[1]:n])

-- | Return `true` if the given function returns `true` for all
Expand Down Expand Up @@ -223,7 +223,7 @@ def any [n] 'a (f: a -> bool) (as: [n]a): bool =
--
-- **Span:** *O(1)*
def spread 't [n] (k: i64) (x: t) (is: [n]i64) (vs: [n]t): *[k]t =
intrinsics.scatter (map (\_ -> x) (0..1..<k), is, vs)
intrinsics.scatter (map (\_ -> x) (0..1..<k)) is vs

-- | Like `spread`, but takes an array indicating the initial values,
-- and has different work complexity.
Expand All @@ -232,7 +232,7 @@ def spread 't [n] (k: i64) (x: t) (is: [n]i64) (vs: [n]t): *[k]t =
--
-- **Span:** *O(1)*
def scatter 't [k] [n] (dest: *[k]t) (is: [n]i64) (vs: [n]t): *[k]t =
intrinsics.scatter (dest, is, vs)
intrinsics.scatter dest is vs

-- | `scatter_2d as is vs` is the equivalent of a `scatter` on a 2-dimensional
-- array.
Expand All @@ -241,7 +241,7 @@ def scatter 't [k] [n] (dest: *[k]t) (is: [n]i64) (vs: [n]t): *[k]t =
--
-- **Span:** *O(1)*
def scatter_2d 't [k] [n] [l] (dest: *[k][n]t) (is: [l](i64, i64)) (vs: [l]t): *[k][n]t =
intrinsics.scatter_2d (dest, is, vs)
intrinsics.scatter_2d dest is vs

-- | `scatter_3d as is vs` is the equivalent of a `scatter` on a 3-dimensional
-- array.
Expand All @@ -250,4 +250,4 @@ def scatter_2d 't [k] [n] [l] (dest: *[k][n]t) (is: [l](i64, i64)) (vs: [l]t): *
--
-- **Span:** *O(1)*
def scatter_3d 't [k] [n] [o] [l] (dest: *[k][n][o]t) (is: [l](i64, i64, i64)) (vs: [l]t): *[k][n][o]t =
intrinsics.scatter_3d (dest, is, vs)
intrinsics.scatter_3d dest is vs
4 changes: 2 additions & 2 deletions prelude/zip.fut
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
-- depended upon by soacs.fut. So we just define a quick-and-dirty
-- internal one here that uses the intrinsic version.
local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a): [n]x =
intrinsics.map (f, as)
intrinsics.map f as

-- | Construct an array of pairs from two arrays.
def zip [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) =
intrinsics.zip (as, bs)
intrinsics.zip as bs

-- | Construct an array of pairs from two arrays.
def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b): *[n](a,b) =
Expand Down
20 changes: 12 additions & 8 deletions src/Futhark/Doc/Generator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ valBindHtml name (ValBind _ _ retdecl (Info rettype) tparams params _ _ _ _) = d
map typeParamName tparams
++ map identName (S.toList $ mconcat $ map patIdents params)
rettype' <- noLink' $ maybe (retTypeHtml rettype) typeExpHtml retdecl
params' <- noLink' $ mapM patternHtml params
params' <- noLink' $ mapM paramHtml params
pure
( keyword "val " <> (H.span ! A.class_ "decl_name") name,
tparams',
Expand Down Expand Up @@ -493,6 +493,10 @@ synopsisValBindBind (name, BoundV tps t) = do
<> ": "
<> t'

dietHtml :: Diet -> Html
dietHtml Consume = "*"
dietHtml Observe = ""

typeHtml :: StructType -> DocM Html
typeHtml t = case t of
Array _ u shape et -> do
Expand All @@ -513,14 +517,14 @@ typeHtml t = case t of
targs' <- mapM typeArgHtml targs
et' <- qualNameHtml et
pure $ prettyU u <> et' <> mconcat (map (" " <>) targs')
Scalar (Arrow _ pname t1 t2) -> do
Scalar (Arrow _ pname d t1 t2) -> do
t1' <- typeHtml t1
t2' <- retTypeHtml t2
pure $ case pname of
Named v ->
parens (vnameHtml v <> ": " <> t1') <> " -> " <> t2'
parens (vnameHtml v <> ": " <> dietHtml d <> t1') <> " -> " <> t2'
Unnamed ->
t1' <> " -> " <> t2'
dietHtml d <> t1' <> " -> " <> t2'
Scalar (Sum cs) -> pipes <$> mapM ppClause (sortConstrs cs)
where
ppClause (n, ts) = joinBy " " . (ppConstr n :) <$> mapM typeHtml ts
Expand Down Expand Up @@ -688,12 +692,12 @@ vnameLink' (VName _ tag) current file =
then "#" ++ show tag
else relativise file current ++ ".html#" ++ show tag

patternHtml :: Pat -> DocM Html
patternHtml pat = do
let (pat_param, t) = patternParam pat
paramHtml :: Pat -> DocM Html
paramHtml pat = do
let (pat_param, d, t) = patternParam pat
t' <- typeHtml t
pure $ case pat_param of
Named v -> parens (vnameHtml v <> ": " <> t')
Named v -> parens (vnameHtml v <> ": " <> dietHtml d <> t')
Unnamed -> t'

relativise :: FilePath -> FilePath -> FilePath
Expand Down
Loading

0 comments on commit 8e11ee9

Please sign in to comment.