Skip to content

Commit

Permalink
Use more structural sizes in prelude functions. (diku-dk#1938)
Browse files Browse the repository at this point in the history
Changes the type of `flatten`/`unflatten` and their multidimensional versions.

Changes the type of `split`.

Adds `resize` - not a big fan of the name.

Removes `concat_to` and `flatten_to`.  These were workarounds for
previous type system limitations that are now much less pressing.
  • Loading branch information
athas authored and razetime committed May 27, 2023
1 parent 44d7151 commit 710fed5
Show file tree
Hide file tree
Showing 70 changed files with 142 additions and 195 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
* Arbitrary expressions of type `i64` are now allowed as sizes. Work
by Lubin Bailly.

* New prelude function `resize`.

### Removed

* The prelude functions `concat_to` and `flatten_to`. They are often
not necessary now, and otherwise `resize` is available.

### Changed

* The prelude functions `flatten` and `unflatten` (and their
multidimensional variants) have more restrictive types.
multidimensional variants), as well as `split`, now have more
precise types.

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion futhark-benchmarks
Submodule futhark-benchmarks updated 41 files
+2 −2 accelerate/canny/canny.fut
+1 −1 accelerate/hashcat/hashcat.fut
+1 −2 accelerate/nbody/nbody-gui.fut
+4 −7 accelerate/nbody/radixtree.fut
+6 −6 accelerate/smoothlife/smoothlife.fut
+6 −0 external-data.txt
+1 −3 finpar/LocVolCalib.fut
+1 −1 finpar/OptionPricing.fut
+2 −2 jgf/crypt/crypt.fut
+8 −8 micro/reduce-segmented.fut
+1 −1 micro/reduce_by_index.fut
+2 −2 micro/transpose.fut
+1 −1 misc/bfast/bfast-cloudy.fut
+1 −1 misc/bfast/bfast.fut
+1 −1 misc/bfast/lib/github.com/diku-dk/linalg/linalg.fut
+1 −1 misc/buddhabrot/buddhabrot.fut
+1 −2 misc/knn-by-kdtree/buildKDtree.fut
+1 −1 misc/knn-by-kdtree/driver-knn.fut
+1 −1 misc/knn-by-kdtree/knn-iteration.fut
+3 −4 misc/knn-by-kdtree/util.fut
+2 −2 misc/life/life.fut
+1 −1 misc/poseidon/poseidon.fut
+3 −4 parboil/histo/histo.fut
+1 −1 parboil/lbm/lbm.fut
+1 −1 parboil/tpacf/tpacf.fut
+0 −8,911 pbbs/breadthFirstSearch/lib/github.com/diku-dk/segmented/segmented_tests.c
+11 −0 pbbs/fut2pbbs.c
+1 −0 pbbs/maximalMatching/data/2Dgrid_E_64000000.in
+1 −0 pbbs/maximalMatching/data/2Dgrid_E_64000000.out
+1 −0 pbbs/maximalMatching/data/rMatGraph_E_10_20000000.in
+1 −0 pbbs/maximalMatching/data/rMatGraph_E_10_20000000.out
+1 −0 pbbs/maximalMatching/data/randLocalGraph_E_10_20000000.in
+1 −0 pbbs/maximalMatching/data/randLocalGraph_E_10_20000000.out
+83 −0 pbbs/maximalMatching/maximalMatching.fut
+7 −1 pbbs/pbbs2fut.c
+4 −7 pbbs/ray/radixtree.fut
+27 −29 rodinia/backprop/backprop.fut
+1 −1 rodinia/hotspot/hotspot.fut
+1 −1 rodinia/lud/lud.fut
+1 −2 rodinia/nw/nw.fut
+1 −1 rodinia/pathfinder/pathfinder.fut
29 changes: 12 additions & 17 deletions prelude/array.fut
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,16 @@ def take [n] 't (i: i64) (x: [n]t): [i]t = x[0:i]
-- **Complexity:** O(1).
def drop [n] 't (i: i64) (x: [n]t): [n-i]t = x[i:]

-- | Statically change the size of an array. Fail at runtime if the
-- imposed size does not match the actual size. Essentially syntactic
-- sugar for a size coercion.
def resize [m] 't (n: i64) (xs: [m]t) : [n]t = xs :> [n]t

-- | Split an array at a given position.
--
-- **Complexity:** O(1).
def split [n] 't (i: i64) (xs: [n]t): ([i]t, [n-i]t) =
(xs[0:i], xs[i:])
def split [n][m] 't (xs: [n+m]t): ([n]t, [m]t) =
(xs[0:n], xs[n:n+m] :> [m]t)

-- | Return the elements of the array in reverse order.
--
Expand All @@ -67,11 +72,6 @@ def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = intrinsics.concat xs ys
-- | An old-fashioned way of saying `++`.
def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys

-- | Concatenation where the result has a predetermined size. If the
-- provided size is wrong, the function will fail with a run-time
-- error.
def concat_to [n] [m] 't (k: i64) (xs: [n]t) (ys: [m]t): *[k]t = xs ++ ys :> [k]t

-- | Rotate an array some number of elements to the left. A negative
-- rotation amount is also supported.
--
Expand Down Expand Up @@ -126,11 +126,6 @@ def copy 't (a: t): *t =
def flatten [n][m] 't (xs: [n][m]t): [n*m]t =
intrinsics.flatten xs

-- | Like `flatten`@term, but where the final size is known. Fails at
-- runtime if the provided size is wrong.
def flatten_to [n][m] 't (l: i64) (xs: [n][m]t): [l]t =
flatten xs :> [l]t

-- | Like `flatten`, but on the outer three dimensions of an array.
def flatten_3d [n][m][l] 't (xs: [n][m][l]t): [n*m*l]t =
flatten (flatten xs)
Expand All @@ -142,16 +137,16 @@ def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): [n*m*l*k]t =
-- | Splits the outer dimension of an array in two.
--
-- **Complexity:** O(1).
def unflatten 't (n: i64) (m: i64) (xs: [n*m]t): [n][m]t =
def unflatten 't [n][m] (xs: [n*m]t): [n][m]t =
intrinsics.unflatten n m xs :> [n][m]t

-- | Like `unflatten`, but produces three dimensions.
def unflatten_3d 't (n: i64) (m: i64) (l: i64) (xs: [n*m*l]t): [n][m][l]t =
unflatten n m (unflatten (n*m) l xs)
def unflatten_3d 't [n][m][l] (xs: [n*m*l]t): [n][m][l]t =
unflatten (unflatten xs)

-- | Like `unflatten`, but produces four dimensions.
def unflatten_4d 't (n: i64) (m: i64) (l: i64) (k: i64) (xs: [n*m*l*k]t): [n][m][l][k]t =
unflatten n m (unflatten_3d (n*m) l k xs)
def unflatten_4d 't [n][m][l][k] (xs: [n*m*l*k]t): [n][m][l][k]t =
unflatten (unflatten_3d xs)

-- | Transpose an array.
--
Expand Down
6 changes: 2 additions & 4 deletions tests/ad/concat0.fut
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
-- output { [1,2,3,4,5,6] }

entry f_jvp xs ys : []i32 =
let m = length xs + length ys
in jvp (uncurry (concat_to m)) (xs,ys) (xs, ys)
jvp (uncurry concat) (xs,ys) (xs, ys)

-- ==
-- entry: f_vjp
-- compiled input { [1,2,3] [4,5,6] }
-- output { [1,2,3] [4,5,6] }

entry f_vjp xs ys : ([]i32, []i32) =
let m = length xs + length ys
in vjp (uncurry (concat_to m)) (xs,ys) (concat_to m xs ys)
vjp (uncurry concat) (xs,ys) (concat xs ys)
8 changes: 4 additions & 4 deletions tests/ad/reshape0.fut
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
-- compiled input { 2i64 2i64 [1,2,3,4] }
-- output { [[1,2],[3,4]] }

entry f_jvp n m (xs: []i32) =
jvp (unflatten n m) xs xs
entry f_jvp n m (xs: [n*m]i32) =
jvp unflatten xs xs

-- ==
-- entry: f_vjp
-- compiled input { 2i64 2i64 [1,2,3,4] }
-- output { [1,2,3,4] }

entry f_vjp n m (xs: []i32) =
vjp (unflatten n m) xs (unflatten n m xs)
entry f_vjp n m (xs: [n*m]i32) =
vjp unflatten xs (unflatten xs)
4 changes: 2 additions & 2 deletions tests/ad/scan3.fut
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def onehot_2d n m x y =
entry fwd_J [n] (input: [n][4]f32) : [n][4][n][4]f32 =
let input = fromarrs input
in tabulate (n*4) (\i -> jvp primal input (fromarrs (onehot_2d n 4 (i/4) (i%4))))
|> map toarrs |> transpose |> map transpose |> map (map (unflatten n 4))
|> map toarrs |> transpose |> map transpose |> map (map unflatten)

entry rev_J [n] (input: [n][4]f32) : [n][4][n][4]f32 =
let input = fromarrs input
in tabulate (n*4) (\i -> vjp primal input (fromarrs (onehot_2d n 4 (i/4) (i%4))))
|> unflatten n 4 |> map (map toarrs)
|> unflatten |> map (map toarrs)
24 changes: 12 additions & 12 deletions tests/ad/scan6.fut
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
-- entry: fwd_J rev_J
-- compiled input { [[1f32, 2f32], [4f32, 3f32], [3f32, 4f32], [4f32, 2f32]] }
-- output {
-- [[[[1f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]],
-- [[0f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]]],
-- [[[3f32, 0f32], [1f32, 1f32], [0f32, 0f32], [0f32, 0f32]],
-- [[0f32, 3f32], [0f32, 2f32], [0f32, 0f32], [0f32, 0f32]]],
-- [[[12f32, 0f32], [4f32, 4f32], [1f32, 7f32], [0f32, 0f32]],
-- [[0f32, 12f32], [0f32, 8f32], [0f32, 6f32], [0f32, 0f32]]],
-- [[[24f32, 0f32], [8f32, 8f32], [2f32, 14f32], [1f32, 31f32]],
-- [[[[1f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]],
-- [[0f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]]],
-- [[[3f32, 0f32], [1f32, 1f32], [0f32, 0f32], [0f32, 0f32]],
-- [[0f32, 3f32], [0f32, 2f32], [0f32, 0f32], [0f32, 0f32]]],
-- [[[12f32, 0f32], [4f32, 4f32], [1f32, 7f32], [0f32, 0f32]],
-- [[0f32, 12f32], [0f32, 8f32], [0f32, 6f32], [0f32, 0f32]]],
-- [[[24f32, 0f32], [8f32, 8f32], [2f32, 14f32], [1f32, 31f32]],
-- [[0f32, 24f32], [0f32, 16f32], [0f32, 12f32], [0f32, 24f32]]]]
-- }

Expand All @@ -26,17 +26,17 @@ def onehot_2d n m x y =
entry fwd_J [n] (input: [n][2]f32) =
let input = fromarrs input
in tabulate (n*2) (\i -> jvp primal input (fromarrs (onehot_2d n 2 (i/2) (i%2))))
|> map toarrs |> transpose |> map transpose |> map (map (unflatten n 2))
|> map toarrs |> transpose |> map transpose |> map (map unflatten)

entry rev_J [n] (input: [n][2]f32) =
let input = fromarrs input
in tabulate (n*2) (\i -> vjp primal input (fromarrs (onehot_2d n 2 (i/2) (i%2))))
|> unflatten n 2 |> map (map toarrs)
|> unflatten |> map (map toarrs)

-- ==
-- entry: fwd_J2 rev_J2
-- compiled input { [[1f32,2f32,3f32,4f32,5f32,6f32],[6f32,5f32,4f32,3f32,2f32,1f32],[4f32,5f32,6f32,1f32,2f32,3f32],[3f32,2f32,1f32,6f32,5f32,4f32]] }
-- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[4f32, 3f32, 0f32, 0f32, 0f32, 0f32], [1f32, 0f32, 1f32, 2f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[2f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 1f32, 0f32, 0f32, 1f32, 2f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 4f32, 0f32, 3f32, 0f32], [0f32, 0f32, 3f32, 5f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 4f32, 0f32, 3f32], [0f32, 0f32, 4f32, 6f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 2f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 3f32, 5f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 2f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 4f32, 6f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[26f32, 19f32, 0f32, 0f32, 0f32, 0f32], [6f32, 1f32, 6f32, 12f32, 1f32, 2f32], [1f32, 0f32, 16f32, 9f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[14f32, 9f32, 0f32, 0f32, 0f32, 0f32], [2f32, 3f32, 2f32, 4f32, 3f32, 6f32], [0f32, 1f32, 0f32, 0f32, 16f32, 9f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 26f32, 0f32, 19f32, 0f32], [0f32, 0f32, 18f32, 30f32, 3f32, 5f32], [0f32, 0f32, 27f32, 11f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 26f32, 0f32, 19f32], [0f32, 0f32, 24f32, 36f32, 4f32, 6f32], [0f32, 0f32, 34f32, 14f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 14f32, 0f32, 9f32, 0f32], [0f32, 0f32, 6f32, 10f32, 9f32, 15f32], [0f32, 0f32, 0f32, 0f32, 27f32, 11f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 14f32, 0f32, 9f32], [0f32, 0f32, 8f32, 12f32, 12f32, 18f32], [0f32, 0f32, 0f32, 0f32, 34f32, 14f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[110f32, 73f32, 0f32, 0f32, 0f32, 0f32], [18f32, 19f32, 18f32, 36f32, 19f32, 38f32], [1f32, 6f32, 16f32, 9f32, 96f32, 54f32], [1f32, 0f32, 109f32, 64f32, 0f32, 0f32]], [[186f32, 131f32, 0f32, 0f32, 0f32, 0f32], [38f32, 17f32, 38f32, 76f32, 17f32, 34f32], [5f32, 4f32, 80f32, 45f32, 64f32, 36f32], [0f32, 1f32, 0f32, 0f32, 109f32, 64f32]], [[0f32, 0f32, 110f32, 0f32, 73f32, 0f32], [0f32, 0f32, 54f32, 90f32, 57f32, 95f32], [0f32, 0f32, 27f32, 11f32, 162f32, 66f32], [0f32, 0f32, 173f32, 87f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 110f32, 0f32, 73f32], [0f32, 0f32, 72f32, 108f32, 76f32, 114f32], [0f32, 0f32, 34f32, 14f32, 204f32, 84f32], [0f32, 0f32, 218f32, 110f32, 0f32, 0f32]], [[0f32, 0f32, 186f32, 0f32, 131f32, 0f32], [0f32, 0f32, 114f32, 190f32, 51f32, 85f32], [0f32, 0f32, 135f32, 55f32, 108f32, 44f32], [0f32, 0f32, 0f32, 0f32, 173f32, 87f32]], [[0f32, 0f32, 0f32, 186f32, 0f32, 131f32], [0f32, 0f32, 152f32, 228f32, 68f32, 102f32], [0f32, 0f32, 170f32, 70f32, 136f32, 56f32], [0f32, 0f32, 0f32, 0f32, 218f32, 110f32]]]] }
-- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[4f32, 3f32, 0f32, 0f32, 0f32, 0f32], [1f32, 0f32, 1f32, 2f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[2f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 1f32, 0f32, 0f32, 1f32, 2f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 4f32, 0f32, 3f32, 0f32], [0f32, 0f32, 3f32, 5f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 4f32, 0f32, 3f32], [0f32, 0f32, 4f32, 6f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 2f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 3f32, 5f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 2f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 4f32, 6f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[26f32, 19f32, 0f32, 0f32, 0f32, 0f32], [6f32, 1f32, 6f32, 12f32, 1f32, 2f32], [1f32, 0f32, 16f32, 9f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[14f32, 9f32, 0f32, 0f32, 0f32, 0f32], [2f32, 3f32, 2f32, 4f32, 3f32, 6f32], [0f32, 1f32, 0f32, 0f32, 16f32, 9f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 26f32, 0f32, 19f32, 0f32], [0f32, 0f32, 18f32, 30f32, 3f32, 5f32], [0f32, 0f32, 27f32, 11f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 26f32, 0f32, 19f32], [0f32, 0f32, 24f32, 36f32, 4f32, 6f32], [0f32, 0f32, 34f32, 14f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 14f32, 0f32, 9f32, 0f32], [0f32, 0f32, 6f32, 10f32, 9f32, 15f32], [0f32, 0f32, 0f32, 0f32, 27f32, 11f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 14f32, 0f32, 9f32], [0f32, 0f32, 8f32, 12f32, 12f32, 18f32], [0f32, 0f32, 0f32, 0f32, 34f32, 14f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[110f32, 73f32, 0f32, 0f32, 0f32, 0f32], [18f32, 19f32, 18f32, 36f32, 19f32, 38f32], [1f32, 6f32, 16f32, 9f32, 96f32, 54f32], [1f32, 0f32, 109f32, 64f32, 0f32, 0f32]], [[186f32, 131f32, 0f32, 0f32, 0f32, 0f32], [38f32, 17f32, 38f32, 76f32, 17f32, 34f32], [5f32, 4f32, 80f32, 45f32, 64f32, 36f32], [0f32, 1f32, 0f32, 0f32, 109f32, 64f32]], [[0f32, 0f32, 110f32, 0f32, 73f32, 0f32], [0f32, 0f32, 54f32, 90f32, 57f32, 95f32], [0f32, 0f32, 27f32, 11f32, 162f32, 66f32], [0f32, 0f32, 173f32, 87f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 110f32, 0f32, 73f32], [0f32, 0f32, 72f32, 108f32, 76f32, 114f32], [0f32, 0f32, 34f32, 14f32, 204f32, 84f32], [0f32, 0f32, 218f32, 110f32, 0f32, 0f32]], [[0f32, 0f32, 186f32, 0f32, 131f32, 0f32], [0f32, 0f32, 114f32, 190f32, 51f32, 85f32], [0f32, 0f32, 135f32, 55f32, 108f32, 44f32], [0f32, 0f32, 0f32, 0f32, 173f32, 87f32]], [[0f32, 0f32, 0f32, 186f32, 0f32, 131f32], [0f32, 0f32, 152f32, 228f32, 68f32, 102f32], [0f32, 0f32, 170f32, 70f32, 136f32, 56f32], [0f32, 0f32, 0f32, 0f32, 218f32, 110f32]]]] }
def mm2by2 (a1, b1, c1, d1)
(a2, b2, c2, d2) : (f32,f32,f32,f32) =
( a1*a2 + b1*c2
Expand Down Expand Up @@ -65,9 +65,9 @@ def toarrs2 = map (\((a,b),(c,d,e,f)) -> [a,b,c,d,e,f])
entry fwd_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 =
let input = fromarrs2 input
in tabulate (n*6) (\i -> jvp primal2 input (fromarrs2 (onehot_2d n 6 (i/6) (i%6))))
|> map toarrs2 |> transpose |> map transpose |> map (map (unflatten n 6))
|> map toarrs2 |> transpose |> map transpose |> map (map unflatten)

entry rev_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 =
let input = fromarrs2 input
in tabulate (n*6) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 6 (i/6) (i%6))))
|> unflatten n 6 |> map (map toarrs2)
|> unflatten |> map (map toarrs2)
4 changes: 2 additions & 2 deletions tests/ad/scan8.fut
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def onehot_2d n m x y =
entry fwd [n] (input: [n][9]f32) : [n][9][n][9]f32 =
let input = fromarrs3 input
in tabulate (n*9) (\i -> jvp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9))))
|> map toarrs3 |> transpose |> map transpose |> map (map (unflatten n 9))
|> map toarrs3 |> transpose |> map transpose |> map (map unflatten)

entry rev [n] (input: [n][9]f32) : [n][9][n][9]f32 =
let input = fromarrs3 input
in tabulate (n*9) (\i -> vjp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9))))
|> unflatten n 9 |> map (map toarrs3)
|> unflatten |> map (map toarrs3)
4 changes: 2 additions & 2 deletions tests/ad/scan9.fut
Original file line number Diff line number Diff line change
Expand Up @@ -466,9 +466,9 @@ def onehot_2d n m x y =
entry fwd [n] (input: [n][16]f32) : [n][16][n][16]f32 =
let input = fromarrs2 input
in tabulate (n*16) (\i -> jvp primal2 input (fromarrs2 (onehot_2d n 16 (i/16) (i%16))))
|> map toarrs2 |> transpose |> map transpose |> map (map (unflatten n 16))
|> map toarrs2 |> transpose |> map transpose |> map (map unflatten)

entry rev [n] (input: [n][16]f32) : [n][16][n][16]f32 =
let input = fromarrs2 input
in tabulate (n*16) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 16 (i/16) (i%16))))
|> unflatten n 16 |> map (map toarrs2)
|> unflatten |> map (map toarrs2)
20 changes: 10 additions & 10 deletions tests/ad/scangenbenchtests.fut
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ def onehot_2d n m x y =
def fwd_J2 [n] (input: [n][4]i32) : [n][4][n][4]i32 =
let input = fromarrs2 input
in tabulate (n*4) (\i -> jvp primal2 input (fromarrs2 (onehot_2d n 4 (i/4) (i%4))))
|> map toarrs2 |> transpose |> map transpose |> map (map (unflatten n 4))
|> map toarrs2 |> transpose |> map transpose |> map (map unflatten)

def rev_J2 [n] (input: [n][4]i32) : [n][4][n][4]i32 =
let input = fromarrs2 input
in tabulate (n*4) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 4 (i/4) (i%4))))
|> unflatten n 4 |> map (map toarrs2)
|> unflatten |> map (map toarrs2)

entry testmm2by2 [n] (input: [n][4]i32) =
let fwd = fwd_J2 input
Expand Down Expand Up @@ -73,12 +73,12 @@ def toarrs3 = map (\(a,b,c,d,e,f,g,h,i) -> [a,b,c,d,e,f,g,h,i])
def fwd_J3 [n] (input: [n][9]i32) : [n][9][n][9]i32 =
let input = fromarrs3 input
in tabulate (n*9) (\i -> jvp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9))))
|> map toarrs3 |> transpose |> map transpose |> map (map (unflatten n 9))
|> map toarrs3 |> transpose |> map transpose |> map (map unflatten)

def rev_J3 [n] (input: [n][9]i32) : [n][9][n][9]i32 =
let input = fromarrs3 input
in tabulate (n*9) (\i -> vjp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9))))
|> unflatten n 9 |> map (map toarrs3)
|> unflatten |> map (map toarrs3)

entry testmm3by3 [n] (input: [n][9]i32) =
let fwd = fwd_J3 input
Expand Down Expand Up @@ -117,12 +117,12 @@ def toarrs4 = map (\(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) -> [a,b,c,d,e,f,g,h,i,j,k,
def fwd_J4 [n] (input: [n][16]i32) : [n][16][n][16]i32 =
let input = fromarrs4 input
in tabulate (n*16) (\i -> jvp primal4 input (fromarrs4 (onehot_2d n 16 (i/16) (i%16))))
|> map toarrs4 |> transpose |> map transpose |> map (map (unflatten n 16))
|> map toarrs4 |> transpose |> map transpose |> map (map unflatten)

def rev_J4 [n] (input: [n][16]i32) : [n][16][n][16]i32 =
let input = fromarrs4 input
in tabulate (n*16) (\i -> vjp primal4 input (fromarrs4 (onehot_2d n 16 (i/16) (i%16))))
|> unflatten n 16 |> map (map toarrs4)
|> unflatten |> map (map toarrs4)

entry testmm4by4 [n] (input: [n][16]i32) =
let fwd = fwd_J4 input
Expand All @@ -138,12 +138,12 @@ def toarrslin = map (\(a,b) -> [a,b])
def fwd_Jlin [n] (input: [n][2]i32) =
let input = fromarrslin input
in tabulate (n*2) (\i -> jvp primallin input (fromarrslin (onehot_2d n 2 (i/2) (i%2))))
|> map toarrslin |> transpose |> map transpose |> map (map (unflatten n 2))
|> map toarrslin |> transpose |> map transpose |> map (map unflatten)

def rev_Jlin [n] (input: [n][2]i32) =
let input = fromarrslin input
in tabulate (n*2) (\i -> vjp primallin input (fromarrslin (onehot_2d n 2 (i/2) (i%2))))
|> unflatten n 2 |> map (map toarrslin)
|> unflatten |> map (map toarrslin)

entry testlin [n] (input: [n][2]i32) =
let fwd = fwd_Jlin input
Expand All @@ -170,12 +170,12 @@ def toarrslin2 = map (\((a,b),(c,d,e,f)) -> [a,b,c,d,e,f])
def fwd_Jlin2 [n] (input: [n][6]i32) : [n][6][n][6]i32 =
let input = fromarrslin2 input
in tabulate (n*6) (\i -> jvp primallin2 input (fromarrslin2 (onehot_2d n 6 (i/6) (i%6))))
|> map toarrslin2 |> transpose |> map transpose |> map (map (unflatten n 6))
|> map toarrslin2 |> transpose |> map transpose |> map (map unflatten)

def rev_Jlin2 [n] (input: [n][6]i32) : [n][6][n][6]i32 =
let input = fromarrslin2 input
in tabulate (n*6) (\i -> vjp primallin2 input (fromarrslin2 (onehot_2d n 6 (i/6) (i%6))))
|> unflatten n 6 |> map (map toarrslin2)
|> unflatten |> map (map toarrslin2)

entry testlin2by2 [n] (input: [n][6]i32) =
let fwd = fwd_Jlin2 input
Expand Down
Loading

0 comments on commit 710fed5

Please sign in to comment.