From 6d032348275f62a128927ae3b82f0b6a8eeb7ec9 Mon Sep 17 00:00:00 2001 From: Nikolaos Bezirgiannis Date: Wed, 11 Oct 2023 11:48:49 +0200 Subject: [PATCH] PLT-7745: Basic rewrite rules for builtins --- plutus-core/plutus-core.cabal | 3 ++ .../src/PlutusIR/Analysis/Builtins.hs | 14 ++++++-- .../plutus-ir/src/PlutusIR/Compiler.hs | 7 +++- .../plutus-ir/src/PlutusIR/Compiler/Types.hs | 2 ++ .../src/PlutusIR/Transform/Rewrite.hs | 36 +++++++++++++++++++ .../test/PlutusIR/Transform/Rewrite/Tests.hs | 22 ++++++++++++ .../Transform/Rewrite/decodeEncodeUtf8 | 5 +++ .../Transform/Rewrite/decodeEncodeUtf8.golden | 1 + plutus-tx-plugin/src/PlutusTx/Plugin.hs | 6 ++-- 9 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 plutus-core/plutus-ir/src/PlutusIR/Transform/Rewrite.hs create mode 100644 plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/Tests.hs create mode 100644 plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8 create mode 100644 plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8.golden diff --git a/plutus-core/plutus-core.cabal b/plutus-core/plutus-core.cabal index 591daa1f691..7483e6d8be9 100644 --- a/plutus-core/plutus-core.cabal +++ b/plutus-core/plutus-core.cabal @@ -510,6 +510,7 @@ library plutus-ir PlutusIR.Transform.NonStrict PlutusIR.Transform.RecSplit PlutusIR.Transform.Rename + PlutusIR.Transform.Rewrite PlutusIR.Transform.StrictifyBindings PlutusIR.Transform.Substitute PlutusIR.Transform.ThunkRecursions @@ -583,6 +584,7 @@ test-suite plutus-ir-test PlutusIR.Scoping.Tests PlutusIR.Transform.Beta.Tests PlutusIR.Transform.CommuteFnWithConst.Tests + PlutusIR.Transform.Rewrite.Tests PlutusIR.Transform.DeadCode.Tests PlutusIR.Transform.EvaluateBuiltins.Tests PlutusIR.Transform.Inline.Tests @@ -615,6 +617,7 @@ test-suite plutus-ir-test , tasty-quickcheck , text , unordered-containers + , data-default-class executable pir import: lang diff --git a/plutus-core/plutus-ir/src/PlutusIR/Analysis/Builtins.hs b/plutus-core/plutus-ir/src/PlutusIR/Analysis/Builtins.hs index f2a5df83d5a..8061b4b20fa 100644 --- a/plutus-core/plutus-ir/src/PlutusIR/Analysis/Builtins.hs +++ b/plutus-core/plutus-ir/src/PlutusIR/Analysis/Builtins.hs @@ -1,19 +1,29 @@ {-# LANGUAGE KindSignatures #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE RankNTypes #-} module PlutusIR.Analysis.Builtins where +import Data.Monoid import Control.Lens.TH import Data.Kind import PlutusCore.Builtin import PlutusCore.Builtin qualified as PLC import PlutusPrelude (Default (..)) +import PlutusIR qualified as PIR + +newtype RewriteRules uni fun = RewriteRules { + unRewriteRule :: forall tyname name a. Semigroup a => Dual (Endo (PIR.Term tyname name uni fun a)) + } -- | All non-static information about builtins that the compiler might want. data BuiltinsInfo (uni :: Type -> Type) fun = BuiltinsInfo { _biSemanticsVariant :: PLC.BuiltinSemanticsVariant fun + , _rewriteRules :: RewriteRules uni fun } - makeLenses ''BuiltinsInfo instance (Default (BuiltinSemanticsVariant fun)) => Default (BuiltinsInfo uni fun) where - def = BuiltinsInfo def + def = BuiltinsInfo + def + -- no rewrite rules by default (aka id). + (RewriteRules mempty) diff --git a/plutus-core/plutus-ir/src/PlutusIR/Compiler.hs b/plutus-core/plutus-ir/src/PlutusIR/Compiler.hs index b8f80eb1537..fb857f41aaa 100644 --- a/plutus-core/plutus-ir/src/PlutusIR/Compiler.hs +++ b/plutus-core/plutus-ir/src/PlutusIR/Compiler.hs @@ -66,6 +66,7 @@ import PlutusIR.Transform.LetMerge qualified as LetMerge import PlutusIR.Transform.NonStrict qualified as NonStrict import PlutusIR.Transform.RecSplit qualified as RecSplit import PlutusIR.Transform.Rename () +import PlutusIR.Transform.Rewrite qualified as Rewrite import PlutusIR.Transform.StrictifyBindings qualified as StrictifyBindings import PlutusIR.Transform.ThunkRecursions qualified as ThunkRec import PlutusIR.Transform.Unwrap qualified as Unwrap @@ -134,6 +135,9 @@ availablePasses = binfo <- view ccBuiltinsInfo Inline.inline hints binfo t ) + , Pass "rewrite rules" (onOption coDoSimplifierRewrite) (\ t -> do + binfo <- view ccBuiltinsInfo + pure $ Rewrite.userRewrite binfo t) , Pass "commuteFnWithConst" (onOption coDoSimplifiercommuteFnWithConst) (pure . CommuteFnWithConst.commuteFnWithConst) ] @@ -144,11 +148,12 @@ simplify simplify = foldl' (>=>) pure (map applyPass availablePasses) -- | Perform some simplification of a 'Term'. +-- +-- NOTE: simplifyTerm requires at least 1 prior dead code elimination pass simplifyTerm :: forall m e uni fun a b. (Compiling m e uni fun a, b ~ Provenance a) => Term TyName Name uni fun b -> m (Term TyName Name uni fun b) simplifyTerm = runIfOpts simplify' - -- NOTE: we need at least one pass of dead code elimination where simplify' :: Term TyName Name uni fun b -> m (Term TyName Name uni fun b) simplify' t = do diff --git a/plutus-core/plutus-ir/src/PlutusIR/Compiler/Types.hs b/plutus-core/plutus-ir/src/PlutusIR/Compiler/Types.hs index e29cccd805a..37f00492248 100644 --- a/plutus-core/plutus-ir/src/PlutusIR/Compiler/Types.hs +++ b/plutus-core/plutus-ir/src/PlutusIR/Compiler/Types.hs @@ -84,6 +84,7 @@ data CompilationOpts a = CompilationOpts { , _coDoSimplifierUnwrapCancel :: Bool , _coDoSimplifierCaseReduce :: Bool , _coDoSimplifiercommuteFnWithConst :: Bool + , _coDoSimplifierRewrite :: Bool , _coDoSimplifierBeta :: Bool , _coDoSimplifierInline :: Bool , _coDoSimplifierKnownCon :: Bool @@ -110,6 +111,7 @@ defaultCompilationOpts = CompilationOpts , _coDoSimplifierUnwrapCancel = True , _coDoSimplifierCaseReduce = True , _coDoSimplifiercommuteFnWithConst = True + , _coDoSimplifierRewrite = True , _coDoSimplifierKnownCon = True , _coDoSimplifierBeta = True , _coDoSimplifierInline = True diff --git a/plutus-core/plutus-ir/src/PlutusIR/Transform/Rewrite.hs b/plutus-core/plutus-ir/src/PlutusIR/Transform/Rewrite.hs new file mode 100644 index 00000000000..385ceab47b8 --- /dev/null +++ b/plutus-core/plutus-ir/src/PlutusIR/Transform/Rewrite.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} +module PlutusIR.Transform.Rewrite where + +import PlutusCore.Default +import PlutusIR.Analysis.Builtins +import PlutusIR +import Data.Monoid + +import Control.Lens + +userRewrite :: (Semigroup a, t ~ Term tyname name uni fun a) + => BuiltinsInfo uni fun + -> t + -> t +userRewrite bi t = + let RewriteRules f = bi^.rewriteRules + in transformOf termSubterms (appEndo $ getDual f) t + +defaultUniRewriteRules :: RewriteRules DefaultUni DefaultFun +defaultUniRewriteRules = RewriteRules $ combineRules + -- rules are applied from left to right because of Dual + [ decodeEncode + ] + where + combineRules = foldMap (Dual . Endo) + +decodeEncode :: Semigroup a => Term tyname name uni DefaultFun a -> Term tyname name uni DefaultFun a +decodeEncode = \case + BA DecodeUtf8 a1 a2 (BA EncodeUtf8 a3 a4 t) -> + -- place the missed annotations inside the rewritten term + (<> a1 <> a2 <> a3 <> a4) <$> t + t -> t + +pattern BA b a1 a2 t <- Apply a1 (Builtin a2 b) t diff --git a/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/Tests.hs b/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/Tests.hs new file mode 100644 index 00000000000..ae8eb2aca10 --- /dev/null +++ b/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/Tests.hs @@ -0,0 +1,22 @@ +module PlutusIR.Transform.Rewrite.Tests where + +import PlutusIR.Parser +import PlutusIR.Test +import PlutusIR.Analysis.Builtins +import PlutusIR.Transform.Rewrite qualified as Rewrite + +import Data.Default.Class +import Control.Lens +import Test.Tasty +import Test.Tasty.Extras + + +test_commuteDefaultFun :: TestTree +test_commuteDefaultFun = runTestNestedIn ["plutus-ir/test/PlutusIR/Transform"] $ + testNested "Rewrite" $ + fmap + (goldenPir (Rewrite.userRewrite builtinsInfo) pTerm) + [ "decodeEncodeUtf8" + ] + where + builtinsInfo = def & rewriteRules .~ Rewrite.defaultUniRewriteRules diff --git a/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8 b/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8 new file mode 100644 index 00000000000..dd010c621fc --- /dev/null +++ b/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8 @@ -0,0 +1,5 @@ +(lam + x (con string) [ (builtin decodeUtf8) [ (builtin encodeUtf8) + [ (builtin decodeUtf8) [ (builtin encodeUtf8) x + ] ] ] ] +) \ No newline at end of file diff --git a/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8.golden b/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8.golden new file mode 100644 index 00000000000..0775693c720 --- /dev/null +++ b/plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/decodeEncodeUtf8.golden @@ -0,0 +1 @@ +(lam x (con string) x) \ No newline at end of file diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin.hs b/plutus-tx-plugin/src/PlutusTx/Plugin.hs index 0bee0463654..7b76dcec9e2 100644 --- a/plutus-tx-plugin/src/PlutusTx/Plugin.hs +++ b/plutus-tx-plugin/src/PlutusTx/Plugin.hs @@ -83,6 +83,8 @@ import PlutusIR.Compiler.Provenance (noProvenance, original) import Prettyprinter qualified as PP import System.IO (openTempFile) import System.IO.Unsafe (unsafePerformIO) +import PlutusIR.Analysis.Builtins +import PlutusIR.Transform.Rewrite data PluginCtx = PluginCtx { pcOpts :: PluginOptions @@ -406,7 +408,7 @@ compileMarkedExpr locStr codeTy origE = do ccBlackholed = mempty, ccCurDef = Nothing, ccModBreaks = modBreaks, - ccBuiltinsInfo = def, + ccBuiltinsInfo = def & rewriteRules .~ defaultUniRewriteRules, ccBuiltinCostModel = def, ccDebugTraceOn = _posDumpCompilationTrace opts } @@ -502,7 +504,7 @@ runCompiler moduleName opts expr = do (if plcVersion < PLC.plcVersion110 then PIR.ScottEncoding else PIR.SumsOfProducts) -- TODO: ensure the same as the one used in the plugin - & set PIR.ccBuiltinsInfo def + & set PIR.ccBuiltinsInfo (def & rewriteRules .~ defaultUniRewriteRules) & set PIR.ccBuiltinCostModel def plcOpts = PLC.defaultCompilationOpts & set (PLC.coSimplifyOpts . UPLC.soMaxSimplifierIterations)