Skip to content

Commit

Permalink
PLT-7745: Basic rewrite rules for builtins
Browse files Browse the repository at this point in the history
  • Loading branch information
bezirg committed Oct 12, 2023
1 parent 75568a1 commit 6d03234
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 5 deletions.
3 changes: 3 additions & 0 deletions plutus-core/plutus-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ library plutus-ir
PlutusIR.Transform.NonStrict
PlutusIR.Transform.RecSplit
PlutusIR.Transform.Rename
PlutusIR.Transform.Rewrite
PlutusIR.Transform.StrictifyBindings
PlutusIR.Transform.Substitute
PlutusIR.Transform.ThunkRecursions
Expand Down Expand Up @@ -583,6 +584,7 @@ test-suite plutus-ir-test
PlutusIR.Scoping.Tests
PlutusIR.Transform.Beta.Tests
PlutusIR.Transform.CommuteFnWithConst.Tests
PlutusIR.Transform.Rewrite.Tests
PlutusIR.Transform.DeadCode.Tests
PlutusIR.Transform.EvaluateBuiltins.Tests
PlutusIR.Transform.Inline.Tests
Expand Down Expand Up @@ -615,6 +617,7 @@ test-suite plutus-ir-test
, tasty-quickcheck
, text
, unordered-containers
, data-default-class

executable pir
import: lang
Expand Down
14 changes: 12 additions & 2 deletions plutus-core/plutus-ir/src/PlutusIR/Analysis/Builtins.hs
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE RankNTypes #-}
module PlutusIR.Analysis.Builtins where

import Data.Monoid
import Control.Lens.TH
import Data.Kind
import PlutusCore.Builtin
import PlutusCore.Builtin qualified as PLC
import PlutusPrelude (Default (..))
import PlutusIR qualified as PIR

newtype RewriteRules uni fun = RewriteRules {
unRewriteRule :: forall tyname name a. Semigroup a => Dual (Endo (PIR.Term tyname name uni fun a))
}

-- | All non-static information about builtins that the compiler might want.
data BuiltinsInfo (uni :: Type -> Type) fun = BuiltinsInfo
{ _biSemanticsVariant :: PLC.BuiltinSemanticsVariant fun
, _rewriteRules :: RewriteRules uni fun
}

makeLenses ''BuiltinsInfo

instance (Default (BuiltinSemanticsVariant fun)) => Default (BuiltinsInfo uni fun) where
def = BuiltinsInfo def
def = BuiltinsInfo
def
-- no rewrite rules by default (aka id).
(RewriteRules mempty)
7 changes: 6 additions & 1 deletion plutus-core/plutus-ir/src/PlutusIR/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import PlutusIR.Transform.LetMerge qualified as LetMerge
import PlutusIR.Transform.NonStrict qualified as NonStrict
import PlutusIR.Transform.RecSplit qualified as RecSplit
import PlutusIR.Transform.Rename ()
import PlutusIR.Transform.Rewrite qualified as Rewrite
import PlutusIR.Transform.StrictifyBindings qualified as StrictifyBindings
import PlutusIR.Transform.ThunkRecursions qualified as ThunkRec
import PlutusIR.Transform.Unwrap qualified as Unwrap
Expand Down Expand Up @@ -134,6 +135,9 @@ availablePasses =
binfo <- view ccBuiltinsInfo
Inline.inline hints binfo t
)
, Pass "rewrite rules" (onOption coDoSimplifierRewrite) (\ t -> do
binfo <- view ccBuiltinsInfo
pure $ Rewrite.userRewrite binfo t)
, Pass "commuteFnWithConst" (onOption coDoSimplifiercommuteFnWithConst) (pure . CommuteFnWithConst.commuteFnWithConst)
]

Expand All @@ -144,11 +148,12 @@ simplify
simplify = foldl' (>=>) pure (map applyPass availablePasses)

-- | Perform some simplification of a 'Term'.
--
-- NOTE: simplifyTerm requires at least 1 prior dead code elimination pass
simplifyTerm
:: forall m e uni fun a b. (Compiling m e uni fun a, b ~ Provenance a)
=> Term TyName Name uni fun b -> m (Term TyName Name uni fun b)
simplifyTerm = runIfOpts simplify'
-- NOTE: we need at least one pass of dead code elimination
where
simplify' :: Term TyName Name uni fun b -> m (Term TyName Name uni fun b)
simplify' t = do
Expand Down
2 changes: 2 additions & 0 deletions plutus-core/plutus-ir/src/PlutusIR/Compiler/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ data CompilationOpts a = CompilationOpts {
, _coDoSimplifierUnwrapCancel :: Bool
, _coDoSimplifierCaseReduce :: Bool
, _coDoSimplifiercommuteFnWithConst :: Bool
, _coDoSimplifierRewrite :: Bool
, _coDoSimplifierBeta :: Bool
, _coDoSimplifierInline :: Bool
, _coDoSimplifierKnownCon :: Bool
Expand All @@ -110,6 +111,7 @@ defaultCompilationOpts = CompilationOpts
, _coDoSimplifierUnwrapCancel = True
, _coDoSimplifierCaseReduce = True
, _coDoSimplifiercommuteFnWithConst = True
, _coDoSimplifierRewrite = True
, _coDoSimplifierKnownCon = True
, _coDoSimplifierBeta = True
, _coDoSimplifierInline = True
Expand Down
36 changes: 36 additions & 0 deletions plutus-core/plutus-ir/src/PlutusIR/Transform/Rewrite.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module PlutusIR.Transform.Rewrite where

import PlutusCore.Default
import PlutusIR.Analysis.Builtins
import PlutusIR
import Data.Monoid

import Control.Lens

userRewrite :: (Semigroup a, t ~ Term tyname name uni fun a)
=> BuiltinsInfo uni fun
-> t
-> t
userRewrite bi t =
let RewriteRules f = bi^.rewriteRules
in transformOf termSubterms (appEndo $ getDual f) t

defaultUniRewriteRules :: RewriteRules DefaultUni DefaultFun
defaultUniRewriteRules = RewriteRules $ combineRules
-- rules are applied from left to right because of Dual
[ decodeEncode
]
where
combineRules = foldMap (Dual . Endo)

decodeEncode :: Semigroup a => Term tyname name uni DefaultFun a -> Term tyname name uni DefaultFun a
decodeEncode = \case
BA DecodeUtf8 a1 a2 (BA EncodeUtf8 a3 a4 t) ->
-- place the missed annotations inside the rewritten term
(<> a1 <> a2 <> a3 <> a4) <$> t
t -> t

pattern BA b a1 a2 t <- Apply a1 (Builtin a2 b) t
22 changes: 22 additions & 0 deletions plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/Tests.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module PlutusIR.Transform.Rewrite.Tests where

import PlutusIR.Parser
import PlutusIR.Test
import PlutusIR.Analysis.Builtins
import PlutusIR.Transform.Rewrite qualified as Rewrite

import Data.Default.Class
import Control.Lens
import Test.Tasty
import Test.Tasty.Extras


test_commuteDefaultFun :: TestTree
test_commuteDefaultFun = runTestNestedIn ["plutus-ir/test/PlutusIR/Transform"] $
testNested "Rewrite" $
fmap
(goldenPir (Rewrite.userRewrite builtinsInfo) pTerm)
[ "decodeEncodeUtf8"
]
where
builtinsInfo = def & rewriteRules .~ Rewrite.defaultUniRewriteRules
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
(lam
x (con string) [ (builtin decodeUtf8) [ (builtin encodeUtf8)
[ (builtin decodeUtf8) [ (builtin encodeUtf8) x
] ] ] ]
)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(lam x (con string) x)
6 changes: 4 additions & 2 deletions plutus-tx-plugin/src/PlutusTx/Plugin.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ import PlutusIR.Compiler.Provenance (noProvenance, original)
import Prettyprinter qualified as PP
import System.IO (openTempFile)
import System.IO.Unsafe (unsafePerformIO)
import PlutusIR.Analysis.Builtins
import PlutusIR.Transform.Rewrite

data PluginCtx = PluginCtx
{ pcOpts :: PluginOptions
Expand Down Expand Up @@ -406,7 +408,7 @@ compileMarkedExpr locStr codeTy origE = do
ccBlackholed = mempty,
ccCurDef = Nothing,
ccModBreaks = modBreaks,
ccBuiltinsInfo = def,
ccBuiltinsInfo = def & rewriteRules .~ defaultUniRewriteRules,
ccBuiltinCostModel = def,
ccDebugTraceOn = _posDumpCompilationTrace opts
}
Expand Down Expand Up @@ -502,7 +504,7 @@ runCompiler moduleName opts expr = do
(if plcVersion < PLC.plcVersion110
then PIR.ScottEncoding else PIR.SumsOfProducts)
-- TODO: ensure the same as the one used in the plugin
& set PIR.ccBuiltinsInfo def
& set PIR.ccBuiltinsInfo (def & rewriteRules .~ defaultUniRewriteRules)
& set PIR.ccBuiltinCostModel def
plcOpts = PLC.defaultCompilationOpts
& set (PLC.coSimplifyOpts . UPLC.soMaxSimplifierIterations)
Expand Down

0 comments on commit 6d03234

Please sign in to comment.