diff --git a/CHANGELOG.md b/CHANGELOG.md index dff6b1b5c1..12865289ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Corner case optimisation for mapping over `iota` (#1874). +* AD for certain combinations of `map` and indexing (#1878). + ## [0.23.1] ### Added diff --git a/src/Futhark/AD/Rev/Map.hs b/src/Futhark/AD/Rev/Map.hs index 47577266cb..ca6898f684 100644 --- a/src/Futhark/AD/Rev/Map.hs +++ b/src/Futhark/AD/Rev/Map.hs @@ -1,5 +1,7 @@ {-# LANGUAGE TypeFamilies #-} +-- | VJP transformation for Map SOACs. This is a pretty complicated +-- case due to the possibility of free variables. module Futhark.AD.Rev.Map (vjpMap) where import Control.Monad @@ -73,6 +75,8 @@ withAcc inputs m = do subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params letTupExp "withhacc_res" $ WithAcc inputs acc_lam +-- | Perform VJP on a Map. The 'Adj' list is the adjoints of the +-- result of the map. vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM () vjpMap ops res_adjs _ w map_lam as | Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do @@ -80,7 +84,7 @@ vjpMap ops res_adjs _ w map_lam as -- (length res_ivs), there is no need for the return sweep code to -- contain a Map at all. - free <- filterM isActive $ namesToList $ freeIn map_lam + free <- filterM isActive $ namesToList $ freeIn map_lam `namesSubtract` namesFromList as free_ts <- mapM lookupType free let adjs_for = map paramName (lambdaParams map_lam) ++ free adjs_ts = map paramType (lambdaParams map_lam) ++ free_ts @@ -96,7 +100,13 @@ vjpMap ops res_adjs _ w map_lam as forM_ (zip as adjs_ts) $ \(a, t) -> do scratch <- letSubExp "oo_scratch" =<< eBlank t updateAdjIndex a (OutOfBounds, adj_i) scratch - first subExpsRes . adjsReps <$> mapM lookupAdj as + -- We must make sure that all free variables have the same + -- representation in the oo-branch as in the ib-branch. + -- In practice we do this by manifesting the adjoint. + -- This is probably efficient, since the adjoint of a free + -- variable is probably either a scalar or an accumulator. + forM_ free $ \v -> insAdj v =<< adjVal =<< lookupAdj v + first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free) inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do forM_ (zip (lambdaParams map_lam) as) $ \(p, a) -> do a_t <- lookupType a @@ -105,15 +115,19 @@ vjpMap ops res_adjs _ w map_lam as adj_elems <- fmap (map resSubExp) . bodyBind . lambdaBody =<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam - forM_ (zip as adj_elems) $ \(a, a_adj_elem) -> do + let (as_adj_elems, free_adj_elems) = splitAt (length as) adj_elems + forM_ (zip as as_adj_elems) $ \(a, a_adj_elem) -> updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem - first subExpsRes . adjsReps <$> mapM lookupAdj as + forM_ (zip free free_adj_elems) $ \(v, adj_se) -> do + adj_se_v <- letExp "adj_v" (BasicOp $ SubExp adj_se) + insAdj v adj_se_v + first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free) -- Generate an iteration of the map function for every -- position. This is a bit inefficient - probably we could do -- some deduplication. forPos res_i (check, adj_i, adj_v) = do - as_adj <- + adjs <- case check of CheckBounds b -> do (obbranch, mkadjs) <- ooBounds adj_i @@ -129,7 +143,7 @@ vjpMap ops res_adjs _ w map_lam as OutOfBounds -> mapM lookupAdj as - zipWithM setAdj as as_adj + zipWithM setAdj (as <> free) adjs -- Generate an iteration of the map function for every result. forRes res_i = mapM_ (forPos res_i) diff --git a/tests/ad/map6.fut b/tests/ad/map6.fut new file mode 100644 index 0000000000..b7a40aafa3 --- /dev/null +++ b/tests/ad/map6.fut @@ -0,0 +1,29 @@ +-- #1878 +-- == +-- entry: fwd_J rev_J +-- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } +-- output { [[0.0, 2.0, 3.0, 4.0], +-- [0.0, 0.0, 1.0, 1.0], +-- [0.0, 0.0, 0.0, 1.0], +-- [0.0, 0.0, 0.0, 0.0], +-- [-4.0, -6.0, -7.0, -8.0], +-- [0.0, 0.0, -1.0, -1.0], +-- [0.0, 0.0, 0.0, -1.0], +-- [0.0, 0.0, 0.0, 0.0]] +-- } + +def obj (x : [8]f64) = + #[unsafe] -- For simplicity of generated code. + let col_w_pre_red = + tabulate_3d 4 2 4 (\ k i j -> x[k+j]*x[i+j]) + let col_w_red = + map (map f64.sum) col_w_pre_red + let col_eq : [4]f64 = + map (\w -> w[0] - w[1]) col_w_red + in col_eq + +entry fwd_J (x : [8]f64) = + tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) + +entry rev_J (x : [8]f64) = + transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) diff --git a/tests/ad/map7.fut b/tests/ad/map7.fut new file mode 100644 index 0000000000..01f5c3e249 --- /dev/null +++ b/tests/ad/map7.fut @@ -0,0 +1,22 @@ +-- #1878. The interesting thing here is that the sparse adjoint also +-- has active free variables. +-- == +-- entry: fwd_J rev_J +-- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } +-- output { [0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0, 0.0] } + +def obj (x : [8]f64) = + #[unsafe] -- For simplicity of generated code. + let col_w_pre_red = + tabulate_3d 4 2 4 (\ k i j -> x[k+j]*x[i+j]) + let col_w_red = + map (map f64.sum) col_w_pre_red + let col_eq : [4]f64 = + map (\w -> w[0] - w[1]) col_w_red + in f64.maximum col_eq + +entry fwd_J (x : [8]f64) = + tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) + +entry rev_J (x : [8]f64) = + vjp obj x 1