diff --git a/plugins/hls-tactics-plugin/hls-tactics-plugin.cabal b/plugins/hls-tactics-plugin/hls-tactics-plugin.cabal index 393d1f8cd4..f6e6b8fd4e 100644 --- a/plugins/hls-tactics-plugin/hls-tactics-plugin.cabal +++ b/plugins/hls-tactics-plugin/hls-tactics-plugin.cabal @@ -25,6 +25,7 @@ library Ide.Plugin.Tactic Ide.Plugin.Tactic.Auto Ide.Plugin.Tactic.CodeGen + Ide.Plugin.Tactic.CodeGen.Utils Ide.Plugin.Tactic.Context Ide.Plugin.Tactic.Debug Ide.Plugin.Tactic.GHC @@ -34,6 +35,7 @@ library Ide.Plugin.Tactic.Machinery Ide.Plugin.Tactic.Naming Ide.Plugin.Tactic.Range + Ide.Plugin.Tactic.Simplify Ide.Plugin.Tactic.Tactics Ide.Plugin.Tactic.Types Ide.Plugin.Tactic.TestTypes diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs index 5182161f25..a261080aab 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic.hs @@ -327,6 +327,7 @@ tacticCmd tac lf state (TacticParams uri range var_name) $ ResponseError InvalidRequest (T.pack $ show err) Nothing Right rtr -> do traceMX "solns" $ rtr_other_solns rtr + traceMX "after simplification" $ rtr_extract rtr let g = graft (RealSrcSpan span) $ rtr_extract rtr response = transform dflags (clientCapabilities lf) uri g pm pure $ case response of diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs index 1cab232a7a..029eb971d3 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs @@ -1,10 +1,13 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} -module Ide.Plugin.Tactic.CodeGen where +module Ide.Plugin.Tactic.CodeGen + ( module Ide.Plugin.Tactic.CodeGen + , module Ide.Plugin.Tactic.CodeGen.Utils + ) where import Control.Lens ((+~), (%~), (<>~)) import Control.Monad.Except @@ -18,7 +21,6 @@ import Data.Traversable import DataCon import Development.IDE.GHC.Compat import GHC.Exts -import GHC.SourceGen (RdrNameStr) import GHC.SourceGen.Binds import GHC.SourceGen.Expr import GHC.SourceGen.Overloaded @@ -28,7 +30,7 @@ import Ide.Plugin.Tactic.Judgements import Ide.Plugin.Tactic.Machinery import Ide.Plugin.Tactic.Naming import Ide.Plugin.Tactic.Types -import Name +import Ide.Plugin.Tactic.CodeGen.Utils import Type hiding (Var) @@ -202,57 +204,3 @@ buildDataCon jdg dc apps = do . (rose (show dc) $ pure tr,) $ mkCon dc sgs - -mkCon :: DataCon -> [LHsExpr GhcPs] -> LHsExpr GhcPs -mkCon dcon (fmap unLoc -> args) - | isTupleDataCon dcon = - noLoc $ tuple args - | dataConIsInfix dcon - , (lhs : rhs : args') <- args = - noLoc $ foldl' (@@) (op lhs (coerceName dcon_name) rhs) args' - | otherwise = - noLoc $ foldl' (@@) (bvar' $ occName dcon_name) args - where - dcon_name = dataConName dcon - - - -coerceName :: HasOccName a => a -> RdrNameStr -coerceName = fromString . occNameString . occName - - - ------------------------------------------------------------------------------- --- | Like 'var', but works over standard GHC 'OccName's. -var' :: Var a => OccName -> a -var' = var . fromString . occNameString - - ------------------------------------------------------------------------------- --- | Like 'bvar', but works over standard GHC 'OccName's. -bvar' :: BVar a => OccName -> a -bvar' = bvar . fromString . occNameString - - ------------------------------------------------------------------------------- --- | Get an HsExpr corresponding to a function name. -mkFunc :: String -> HsExpr GhcPs -mkFunc = var' . mkVarOcc - - ------------------------------------------------------------------------------- --- | Get an HsExpr corresponding to a value name. -mkVal :: String -> HsExpr GhcPs -mkVal = var' . mkVarOcc - - ------------------------------------------------------------------------------- --- | Like 'op', but easier to call. -infixCall :: String -> HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs -infixCall s = flip op (fromString s) - - ------------------------------------------------------------------------------- --- | Like '(@@)', but uses a dollar instead of parentheses. -appDollar :: HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs -appDollar = infixCall "$" diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen/Utils.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen/Utils.hs new file mode 100644 index 0000000000..e3551cc660 --- /dev/null +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen/Utils.hs @@ -0,0 +1,67 @@ +{-# LANGUAGE ViewPatterns #-} + +module Ide.Plugin.Tactic.CodeGen.Utils where + +import Data.List +import DataCon +import Development.IDE.GHC.Compat +import GHC.Exts +import GHC.SourceGen (RdrNameStr) +import GHC.SourceGen.Overloaded +import Name + + +------------------------------------------------------------------------------ +-- | Make a data constructor with the given arguments. +mkCon :: DataCon -> [LHsExpr GhcPs] -> LHsExpr GhcPs +mkCon dcon (fmap unLoc -> args) + | isTupleDataCon dcon = + noLoc $ tuple args + | dataConIsInfix dcon + , (lhs : rhs : args') <- args = + noLoc $ foldl' (@@) (op lhs (coerceName dcon_name) rhs) args' + | otherwise = + noLoc $ foldl' (@@) (bvar' $ occName dcon_name) args + where + dcon_name = dataConName dcon + + +coerceName :: HasOccName a => a -> RdrNameStr +coerceName = fromString . occNameString . occName + + +------------------------------------------------------------------------------ +-- | Like 'var', but works over standard GHC 'OccName's. +var' :: Var a => OccName -> a +var' = var . fromString . occNameString + + +------------------------------------------------------------------------------ +-- | Like 'bvar', but works over standard GHC 'OccName's. +bvar' :: BVar a => OccName -> a +bvar' = bvar . fromString . occNameString + + +------------------------------------------------------------------------------ +-- | Get an HsExpr corresponding to a function name. +mkFunc :: String -> HsExpr GhcPs +mkFunc = var' . mkVarOcc + + +------------------------------------------------------------------------------ +-- | Get an HsExpr corresponding to a value name. +mkVal :: String -> HsExpr GhcPs +mkVal = var' . mkVarOcc + + +------------------------------------------------------------------------------ +-- | Like 'op', but easier to call. +infixCall :: String -> HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs +infixCall s = flip op (fromString s) + + +------------------------------------------------------------------------------ +-- | Like '(@@)', but uses a dollar instead of parentheses. +appDollar :: HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs +appDollar = infixCall "$" + diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/GHC.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/GHC.hs index 5cba1d20b6..efe715d12c 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/GHC.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/GHC.hs @@ -112,13 +112,17 @@ lambdaCaseable (splitFunTy_maybe -> Just (arg, res)) = Just $ isJust $ algebraicTyCon res lambdaCaseable _ = Nothing -fromPatCompat :: PatCompat GhcTc -> Pat GhcTc +-- It's hard to generalize over these since weird type families are involved. +fromPatCompatTc :: PatCompat GhcTc -> Pat GhcTc +fromPatCompatPs :: PatCompat GhcPs -> Pat GhcPs #if __GLASGOW_HASKELL__ == 808 type PatCompat pass = Pat pass -fromPatCompat = id +fromPatCompatTc = id +fromPatCompatPs = id #else type PatCompat pass = LPat pass -fromPatCompat = unLoc +fromPatCompatTc = unLoc +fromPatCompatPs = unLoc #endif ------------------------------------------------------------------------------ @@ -132,7 +136,7 @@ pattern TopLevelRHS name ps body <- [L _ (GRHS _ [] body)] _) getPatName :: PatCompat GhcTc -> Maybe OccName -getPatName (fromPatCompat -> p0) = +getPatName (fromPatCompatTc -> p0) = case p0 of VarPat _ x -> Just $ occName $ unLoc x LazyPat _ p -> getPatName p diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs index dd307da2ca..787fb6bb7d 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs @@ -33,6 +33,7 @@ import Data.Set (Set) import qualified Data.Set as S import Development.IDE.GHC.Compat import Ide.Plugin.Tactic.Judgements +import Ide.Plugin.Tactic.Simplify (simplify) import Ide.Plugin.Tactic.Types import OccName (HasOccName(occName)) import Refinery.ProofState @@ -97,7 +98,7 @@ runTactic ctx jdg t = case sorted of (((tr, ext), _) : _) -> Right - . RunTacticResults tr ext + . RunTacticResults tr (simplify ext) . reverse . fmap fst $ take 5 sorted diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Simplify.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Simplify.hs new file mode 100644 index 0000000000..c125d50876 --- /dev/null +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Simplify.hs @@ -0,0 +1,122 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE ViewPatterns #-} + +module Ide.Plugin.Tactic.Simplify + ( simplify + ) where + +import Data.Data (Data) +import Data.Generics (everywhere, somewhere, something, listify, extT, mkT, GenericT, mkQ) +import Data.List.Extra (unsnoc) +import Data.Maybe (isJust) +import Data.Monoid (Endo (..)) +import Development.IDE.GHC.Compat +import GHC.Exts (fromString) +import GHC.SourceGen (var, op) +import GHC.SourceGen.Expr (lambda) +import Ide.Plugin.Tactic.CodeGen.Utils +import Ide.Plugin.Tactic.GHC (fromPatCompatPs) + + +------------------------------------------------------------------------------ +-- | A pattern over the otherwise (extremely) messy AST for lambdas. +pattern Lambda :: [Pat GhcPs] -> HsExpr GhcPs -> HsExpr GhcPs +pattern Lambda pats body <- + HsLam _ + (MG {mg_alts = L _ [L _ + (Match { m_pats = fmap fromPatCompatPs -> pats + , m_grhss = GRHSs {grhssGRHSs = [L _ ( + GRHS _ [] (L _ body))]} + })]}) + where + -- If there are no patterns to bind, just stick in the body + Lambda [] body = body + Lambda pats body = lambda pats body + + +------------------------------------------------------------------------------ +-- | Simlify an expression. +simplify :: LHsExpr GhcPs -> LHsExpr GhcPs +simplify + = head + . drop 3 -- Do three passes; this should be good enough for the limited + -- amount of gas we give to auto + . iterate (everywhere $ foldEndo + [ simplifyEtaReduce + , simplifyRemoveParens + , simplifyCompose + ]) + + +------------------------------------------------------------------------------ +-- | Like 'foldMap' but for endomorphisms. +foldEndo :: Foldable t => t (a -> a) -> a -> a +foldEndo = appEndo . foldMap Endo + + +------------------------------------------------------------------------------ +-- | Does this thing contain any references to 'HsVar's with the given +-- 'RdrName'? +containsHsVar :: Data a => RdrName -> a -> Bool +containsHsVar name x = not $ null $ listify ( + \case + ((HsVar _ (L _ a)) :: HsExpr GhcPs) | a == name -> True + _ -> False + ) x + + +------------------------------------------------------------------------------ +-- | Perform an eta reduction. For example, transforms @\x -> (f g) x@ into +-- @f g@. +simplifyEtaReduce :: GenericT +simplifyEtaReduce = mkT $ \case + Lambda + [VarPat _ (L _ pat)] + (HsVar _ (L _ a)) | pat == a -> + var "id" + Lambda + (unsnoc -> Just (pats, (VarPat _ (L _ pat)))) + (HsApp _ (L _ f) (L _ (HsVar _ (L _ a)))) + | pat == a + -- We can only perform this simplifiation if @pat@ is otherwise unused. + , not (containsHsVar pat f) -> + Lambda pats f + x -> x + + +------------------------------------------------------------------------------ +-- | Perform an eta-reducing function composition. For example, transforms +-- @\x -> f (g (h x))@ into @f . g . h@. +simplifyCompose :: GenericT +simplifyCompose = mkT $ \case + Lambda + (unsnoc -> Just (pats, (VarPat _ (L _ pat)))) + (unroll -> (fs@(_:_), (HsVar _ (L _ a)))) + | pat == a + -- We can only perform this simplifiation if @pat@ is otherwise unused. + , not (containsHsVar pat fs) -> + Lambda pats (foldr1 (infixCall ".") fs) + x -> x + + +------------------------------------------------------------------------------ +-- | Removes unnecessary parentheses on any token that doesn't need them. +simplifyRemoveParens :: GenericT +simplifyRemoveParens = mkT $ \case + HsPar _ (L _ x) | isAtomicHsExpr x -> x + (x :: HsExpr GhcPs) -> x + + +------------------------------------------------------------------------------ +-- | Unrolls a right-associative function application of the form +-- @HsApp f (HsApp g (HsApp h x))@ into @([f, g, h], x)@. +unroll :: HsExpr GhcPs -> ([HsExpr GhcPs], HsExpr GhcPs) +unroll (HsPar _ (L _ x)) = unroll x +unroll (HsApp _ (L _ f) (L _ a)) = + let (fs, r) = unroll a + in (f : fs, r) +unroll x = ([], x) + diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs index ac0ab3dff1..a60049de48 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs @@ -70,6 +70,9 @@ instance Show DataCon where instance Show Class where show = unsafeRender +instance Show (HsExpr GhcPs) where + show = unsafeRender + ------------------------------------------------------------------------------ data TacticState = TacticState diff --git a/test/functional/Tactic.hs b/test/functional/Tactic.hs index 6e33a96a90..d46dc8ff29 100644 --- a/test/functional/Tactic.hs +++ b/test/functional/Tactic.hs @@ -117,6 +117,9 @@ tests = testGroup , expectFail "GoldenFish.hs" 5 18 Auto "" , goldenTest "GoldenArbitrary.hs" 25 13 Auto "" , goldenTest "FmapBoth.hs" 2 12 Auto "" + , goldenTest "FmapJoin.hs" 2 14 Auto "" + , goldenTest "Fgmap.hs" 2 9 Auto "" + , goldenTest "FmapJoinInLet.hs" 4 19 Auto "" ] diff --git a/test/testdata/tactic/Fgmap.hs b/test/testdata/tactic/Fgmap.hs new file mode 100644 index 0000000000..de1968474e --- /dev/null +++ b/test/testdata/tactic/Fgmap.hs @@ -0,0 +1,2 @@ +fgmap :: (Functor f, Functor g) => (a -> b) -> (f (g a) -> f (g b)) +fgmap = _ diff --git a/test/testdata/tactic/Fgmap.hs.expected b/test/testdata/tactic/Fgmap.hs.expected new file mode 100644 index 0000000000..8c0b9a2f4a --- /dev/null +++ b/test/testdata/tactic/Fgmap.hs.expected @@ -0,0 +1,2 @@ +fgmap :: (Functor f, Functor g) => (a -> b) -> (f (g a) -> f (g b)) +fgmap = (fmap . fmap) diff --git a/test/testdata/tactic/FmapJoin.hs b/test/testdata/tactic/FmapJoin.hs new file mode 100644 index 0000000000..98a40133ea --- /dev/null +++ b/test/testdata/tactic/FmapJoin.hs @@ -0,0 +1,2 @@ +fJoin :: (Monad m, Monad f) => f (m (m a)) -> f (m a) +fJoin = fmap _ diff --git a/test/testdata/tactic/FmapJoin.hs.expected b/test/testdata/tactic/FmapJoin.hs.expected new file mode 100644 index 0000000000..8064301c89 --- /dev/null +++ b/test/testdata/tactic/FmapJoin.hs.expected @@ -0,0 +1,2 @@ +fJoin :: (Monad m, Monad f) => f (m (m a)) -> f (m a) +fJoin = fmap (\ mma -> (>>=) mma id) diff --git a/test/testdata/tactic/FmapJoinInLet.hs b/test/testdata/tactic/FmapJoinInLet.hs new file mode 100644 index 0000000000..e6fe6cbd0d --- /dev/null +++ b/test/testdata/tactic/FmapJoinInLet.hs @@ -0,0 +1,4 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +fJoin :: forall f m a. (Monad m, Monad f) => f (m (m a)) -> f (m a) +fJoin = let f = (_ :: m (m a) -> m a) in fmap f diff --git a/test/testdata/tactic/FmapJoinInLet.hs.expected b/test/testdata/tactic/FmapJoinInLet.hs.expected new file mode 100644 index 0000000000..a9a9f04f9e --- /dev/null +++ b/test/testdata/tactic/FmapJoinInLet.hs.expected @@ -0,0 +1,4 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +fJoin :: forall f m a. (Monad m, Monad f) => f (m (m a)) -> f (m a) +fJoin = let f = ( (\ mma -> (>>=) mma id) :: m (m a) -> m a) in fmap f diff --git a/test/testdata/tactic/GoldenIdTypeFam.hs.expected b/test/testdata/tactic/GoldenIdTypeFam.hs.expected index ad5697334e..7b3d1beda0 100644 --- a/test/testdata/tactic/GoldenIdTypeFam.hs.expected +++ b/test/testdata/tactic/GoldenIdTypeFam.hs.expected @@ -4,4 +4,4 @@ type family TyFam type instance TyFam = Int tyblah' :: TyFam -> Int -tyblah' = (\ i -> i) +tyblah' = id diff --git a/test/testdata/tactic/GoldenShowCompose.hs.expected b/test/testdata/tactic/GoldenShowCompose.hs.expected index 373ea6af91..e672cc6a02 100644 --- a/test/testdata/tactic/GoldenShowCompose.hs.expected +++ b/test/testdata/tactic/GoldenShowCompose.hs.expected @@ -1,2 +1,2 @@ showCompose :: Show a => (b -> a) -> b -> String -showCompose = (\ fba b -> show (fba b)) +showCompose = (\ fba -> show . fba)