Skip to content

Commit

Permalink
Fix #2199.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Dec 11, 2024
1 parent c849119 commit bc24c74
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

* Inconsistent handling of types in lambda lifting (#2197).

* Invalid primal results from `vjp2` in interpreter (#2199).

## [0.25.24]

### Added
Expand Down
2 changes: 1 addition & 1 deletion src/Language/Futhark/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2010,7 +2010,7 @@ initialCtx =
let drvs = M.map (Just . putAD) $ M.unionsWith add $ map snd m

-- Extract the output values, and the partial derivatives
let ov = modifyValue (\i _ -> fst $ m !! i) o
let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o
let od =
fromMaybe (error "vjp: differentiation failed") $
modifyValueM (\i vo -> M.findWithDefault (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo) i drvs) v
Expand Down
14 changes: 14 additions & 0 deletions tests/ad/issue2199.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- ==
-- entry: test_primal test_rev
-- input { [1.0,2.0] [3.0,4.0] [5.0, 6.0] }
-- output { 3.0 7.0 11.0 }

def op (x0, y0, z0) (x1, y1, z1) : (f64, f64, f64) = (x0 + x1, y0 + y1, z0 + z1)
def ne = (0f64, 0f64, 0f64)

def primal xs = reduce_comm op ne xs

entry test_primal as bs cs = primal (zip3 as bs cs)

entry test_rev as bs cs =
(vjp2 (\(as, bs, cs) -> test_primal as bs cs) (as, bs, cs) (1, 1, 1)).0

0 comments on commit bc24c74

Please sign in to comment.