From eca957a99fa4e9d42d3a0e7ecddfaae6060e9a6e Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 5 Dec 2024 17:20:42 +0100 Subject: [PATCH] Fix #2197. --- CHANGELOG.md | 2 ++ src/Futhark/Internalise/LiftLambdas.hs | 17 +++++++++++++---- tests/issue2197.fut | 8 ++++++++ 3 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 tests/issue2197.fut diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ee11f2ccf..71a4800336 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ranges with start==end and negative step size, e.g. `1..0...1` produces `[1]` rather than an invalid range error. +* Inconsistent handling of types in lambda lifting (#2197). + ## [0.25.24] ### Added diff --git a/src/Futhark/Internalise/LiftLambdas.hs b/src/Futhark/Internalise/LiftLambdas.hs index 68532bd9e6..3e19eefbdc 100644 --- a/src/Futhark/Internalise/LiftLambdas.hs +++ b/src/Futhark/Internalise/LiftLambdas.hs @@ -6,6 +6,7 @@ module Futhark.Internalise.LiftLambdas (transformProg) where import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor +import Data.Bitraversable import Data.Foldable import Data.List (partition) import Data.Map.Strict qualified as M @@ -144,6 +145,12 @@ liftFunction fname tparams params (RetType dims ret) funbody = do transformSubExps :: ASTMapper LiftM transformSubExps = identityMapper {mapOnExp = transformExp} +transformType :: TypeBase Exp u -> LiftM (TypeBase Exp u) +transformType = bitraverse transformExp pure + +transformPat :: PatBase Info VName (TypeBase Exp u) -> LiftM (PatBase Info VName (TypeBase Exp u)) +transformPat = traverse transformType + transformExp :: Exp -> LiftM Exp transformExp (AppExp (LetFun fname (tparams, params, _, Info ret, funbody) body _) _) = do funbody' <- bindingParams (map typeParamName tparams) params $ transformExp funbody @@ -156,8 +163,9 @@ transformExp e@(Lambda params body _ (Info ret) _) = do liftFunction fname [] params ret body' <*> pure (typeOf e) transformExp (AppExp (LetPat sizes pat e body loc) appres) = do e' <- transformExp e - body' <- bindingLetPat (map sizeName sizes) pat $ transformExp body - pure $ AppExp (LetPat sizes pat e' body' loc) appres + pat' <- transformPat pat + body' <- bindingLetPat (map sizeName sizes) pat' $ transformExp body + pure $ AppExp (LetPat sizes pat' e' body' loc) appres transformExp (AppExp (Match e cases loc) appres) = do e' <- transformExp e cases' <- mapM transformCase cases @@ -173,10 +181,11 @@ transformExp (AppExp (Loop sizes pat args form body loc) appres) = do form' <- astMap transformSubExps form body' <- bindingForm form' $ transformExp body pure $ AppExp (Loop sizes pat (LoopInitExplicit args') form' body' loc) appres -transformExp e@(Var v (Info t) _) = +transformExp (Var v (Info t) loc) = do + t' <- transformType t -- Note that function-typed variables can only occur in expressions, -- not in other places where VNames/QualNames can occur. - asks $ maybe e ($ t) . M.lookup (qualLeaf v) . envReplace + asks $ maybe (Var v (Info t') loc) ($ t') . M.lookup (qualLeaf v) . envReplace transformExp e = astMap transformSubExps e transformValBind :: ValBind -> LiftM () diff --git a/tests/issue2197.fut b/tests/issue2197.fut new file mode 100644 index 0000000000..329ac72b0c --- /dev/null +++ b/tests/issue2197.fut @@ -0,0 +1,8 @@ +def index_of_first p xs = + loop i = 0 while i < length xs && !p xs[i] do i + 1 + +def span p xs = let i = index_of_first p xs in (take i xs, drop i xs) + +entry part1 [l] (ls: [][l]i32) = + let blank (l: [l]i32) = null l + in span blank ls |> \(x, y) -> (id x, tail y)