Skip to content

Commit

Permalink
Fix #1878.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Feb 15, 2023
1 parent 5f5ccb6 commit 7ed74ec
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

* Corner case optimisation for mapping over `iota` (#1874).

* AD for certain combinations of `map` and indexing (#1878).

## [0.23.1]

### Added
Expand Down
26 changes: 20 additions & 6 deletions src/Futhark/AD/Rev/Map.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{-# LANGUAGE TypeFamilies #-}

-- | VJP transformation for Map SOACs. This is a pretty complicated
-- case due to the possibility of free variables.
module Futhark.AD.Rev.Map (vjpMap) where

import Control.Monad
Expand Down Expand Up @@ -73,14 +75,16 @@ withAcc inputs m = do
subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params
letTupExp "withhacc_res" $ WithAcc inputs acc_lam

-- | Perform VJP on a Map. The 'Adj' list is the adjoints of the
-- result of the map.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap ops res_adjs _ w map_lam as
| Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do
-- Since at most only a constant number of adjoint are nonzero
-- (length res_ivs), there is no need for the return sweep code to
-- contain a Map at all.

free <- filterM isActive $ namesToList $ freeIn map_lam
free <- filterM isActive $ namesToList $ freeIn map_lam `namesSubtract` namesFromList as
free_ts <- mapM lookupType free
let adjs_for = map paramName (lambdaParams map_lam) ++ free
adjs_ts = map paramType (lambdaParams map_lam) ++ free_ts
Expand All @@ -96,7 +100,13 @@ vjpMap ops res_adjs _ w map_lam as
forM_ (zip as adjs_ts) $ \(a, t) -> do
scratch <- letSubExp "oo_scratch" =<< eBlank t
updateAdjIndex a (OutOfBounds, adj_i) scratch
first subExpsRes . adjsReps <$> mapM lookupAdj as
-- We must make sure that all free variables have the same
-- representation in the oo-branch as in the ib-branch.
-- In practice we do this by manifesting the adjoint.
-- This is probably efficient, since the adjoint of a free
-- variable is probably either a scalar or an accumulator.
forM_ free $ \v -> insAdj v =<< adjVal =<< lookupAdj v
first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)
inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
forM_ (zip (lambdaParams map_lam) as) $ \(p, a) -> do
a_t <- lookupType a
Expand All @@ -105,15 +115,19 @@ vjpMap ops res_adjs _ w map_lam as
adj_elems <-
fmap (map resSubExp) . bodyBind . lambdaBody
=<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
forM_ (zip as adj_elems) $ \(a, a_adj_elem) -> do
let (as_adj_elems, free_adj_elems) = splitAt (length as) adj_elems
forM_ (zip as as_adj_elems) $ \(a, a_adj_elem) ->
updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem
first subExpsRes . adjsReps <$> mapM lookupAdj as
forM_ (zip free free_adj_elems) $ \(v, adj_se) -> do
adj_se_v <- letExp "adj_v" (BasicOp $ SubExp adj_se)
insAdj v adj_se_v
first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)

-- Generate an iteration of the map function for every
-- position. This is a bit inefficient - probably we could do
-- some deduplication.
forPos res_i (check, adj_i, adj_v) = do
as_adj <-
adjs <-
case check of
CheckBounds b -> do
(obbranch, mkadjs) <- ooBounds adj_i
Expand All @@ -129,7 +143,7 @@ vjpMap ops res_adjs _ w map_lam as
OutOfBounds ->
mapM lookupAdj as

zipWithM setAdj as as_adj
zipWithM setAdj (as <> free) adjs

-- Generate an iteration of the map function for every result.
forRes res_i = mapM_ (forPos res_i)
Expand Down
29 changes: 29 additions & 0 deletions tests/ad/map6.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
-- #1878
-- ==
-- entry: fwd_J rev_J
-- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] }
-- output { [[0.0, 2.0, 3.0, 4.0],
-- [0.0, 0.0, 1.0, 1.0],
-- [0.0, 0.0, 0.0, 1.0],
-- [0.0, 0.0, 0.0, 0.0],
-- [-4.0, -6.0, -7.0, -8.0],
-- [0.0, 0.0, -1.0, -1.0],
-- [0.0, 0.0, 0.0, -1.0],
-- [0.0, 0.0, 0.0, 0.0]]
-- }

def obj (x : [8]f64) =
#[unsafe] -- For simplicity of generated code.
let col_w_pre_red =
tabulate_3d 4 2 4 (\ k i j -> x[k+j]*x[i+j])
let col_w_red =
map (map f64.sum) col_w_pre_red
let col_eq : [4]f64 =
map (\w -> w[0] - w[1]) col_w_red
in col_eq

entry fwd_J (x : [8]f64) =
tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1))

entry rev_J (x : [8]f64) =
transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1)))
22 changes: 22 additions & 0 deletions tests/ad/map7.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- #1878. The interesting thing here is that the sparse adjoint also
-- has active free variables.
-- ==
-- entry: fwd_J rev_J
-- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] }
-- output { [0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0, 0.0] }

def obj (x : [8]f64) =
#[unsafe] -- For simplicity of generated code.
let col_w_pre_red =
tabulate_3d 4 2 4 (\ k i j -> x[k+j]*x[i+j])
let col_w_red =
map (map f64.sum) col_w_pre_red
let col_eq : [4]f64 =
map (\w -> w[0] - w[1]) col_w_red
in f64.maximum col_eq

entry fwd_J (x : [8]f64) =
tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1))

entry rev_J (x : [8]f64) =
vjp obj x 1

0 comments on commit 7ed74ec

Please sign in to comment.