From bc24c74fb67722ef10346edcee3d9da5193a7b93 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 11 Dec 2024 15:14:50 +0100 Subject: [PATCH] Fix #2199. --- CHANGELOG.md | 2 ++ src/Language/Futhark/Interpreter.hs | 2 +- tests/ad/issue2199.fut | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 tests/ad/issue2199.fut diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a35ad6093..362687814b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * Inconsistent handling of types in lambda lifting (#2197). +* Invalid primal results from `vjp2` in interpreter (#2199). + ## [0.25.24] ### Added diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 009d9ac4e4..42e710cab7 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -2010,7 +2010,7 @@ initialCtx = let drvs = M.map (Just . putAD) $ M.unionsWith add $ map snd m -- Extract the output values, and the partial derivatives - let ov = modifyValue (\i _ -> fst $ m !! i) o + let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o let od = fromMaybe (error "vjp: differentiation failed") $ modifyValueM (\i vo -> M.findWithDefault (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo) i drvs) v diff --git a/tests/ad/issue2199.fut b/tests/ad/issue2199.fut new file mode 100644 index 0000000000..d1c038fd06 --- /dev/null +++ b/tests/ad/issue2199.fut @@ -0,0 +1,14 @@ +-- == +-- entry: test_primal test_rev +-- input { [1.0,2.0] [3.0,4.0] [5.0, 6.0] } +-- output { 3.0 7.0 11.0 } + +def op (x0, y0, z0) (x1, y1, z1) : (f64, f64, f64) = (x0 + x1, y0 + y1, z0 + z1) +def ne = (0f64, 0f64, 0f64) + +def primal xs = reduce_comm op ne xs + +entry test_primal as bs cs = primal (zip3 as bs cs) + +entry test_rev as bs cs = + (vjp2 (\(as, bs, cs) -> test_primal as bs cs) (as, bs, cs) (1, 1, 1)).0