Skip to content

Commit

Permalink
Make CommuteFnWithConst transformation a RewriteRule
Browse files Browse the repository at this point in the history
  • Loading branch information
bezirg committed Oct 12, 2023
1 parent 6d03234 commit 2b41633
Show file tree
Hide file tree
Showing 19 changed files with 111 additions and 104 deletions.
8 changes: 4 additions & 4 deletions plutus-core/plutus-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ library plutus-ir
PlutusIR.Subst
PlutusIR.Transform.Beta
PlutusIR.Transform.CaseReduce
PlutusIR.Transform.CommuteFnWithConst
PlutusIR.Transform.DeadCode
PlutusIR.Transform.EvaluateBuiltins
PlutusIR.Transform.Inline.CallSiteInline
Expand All @@ -510,7 +509,9 @@ library plutus-ir
PlutusIR.Transform.NonStrict
PlutusIR.Transform.RecSplit
PlutusIR.Transform.Rename
PlutusIR.Transform.Rewrite
PlutusIR.Transform.RewriteRules
PlutusIR.Transform.RewriteRules.CommuteFnWithConst
PlutusIR.Transform.RewriteRules.DecodeEncodeUtf8
PlutusIR.Transform.StrictifyBindings
PlutusIR.Transform.Substitute
PlutusIR.Transform.ThunkRecursions
Expand Down Expand Up @@ -583,8 +584,7 @@ test-suite plutus-ir-test
PlutusIR.Purity.Tests
PlutusIR.Scoping.Tests
PlutusIR.Transform.Beta.Tests
PlutusIR.Transform.CommuteFnWithConst.Tests
PlutusIR.Transform.Rewrite.Tests
PlutusIR.Transform.RewriteRules.Tests
PlutusIR.Transform.DeadCode.Tests
PlutusIR.Transform.EvaluateBuiltins.Tests
PlutusIR.Transform.Inline.Tests
Expand Down
6 changes: 2 additions & 4 deletions plutus-core/plutus-ir/src/PlutusIR/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ import PlutusIR.Compiler.Types
import PlutusIR.Error
import PlutusIR.Transform.Beta qualified as Beta
import PlutusIR.Transform.CaseReduce qualified as CaseReduce
import PlutusIR.Transform.CommuteFnWithConst qualified as CommuteFnWithConst
import PlutusIR.Transform.DeadCode qualified as DeadCode
import PlutusIR.Transform.EvaluateBuiltins qualified as EvaluateBuiltins
import PlutusIR.Transform.Inline.Inline qualified as Inline
Expand All @@ -66,7 +65,7 @@ import PlutusIR.Transform.LetMerge qualified as LetMerge
import PlutusIR.Transform.NonStrict qualified as NonStrict
import PlutusIR.Transform.RecSplit qualified as RecSplit
import PlutusIR.Transform.Rename ()
import PlutusIR.Transform.Rewrite qualified as Rewrite
import PlutusIR.Transform.RewriteRules qualified as RewriteRules
import PlutusIR.Transform.StrictifyBindings qualified as StrictifyBindings
import PlutusIR.Transform.ThunkRecursions qualified as ThunkRec
import PlutusIR.Transform.Unwrap qualified as Unwrap
Expand Down Expand Up @@ -137,8 +136,7 @@ availablePasses =
)
, Pass "rewrite rules" (onOption coDoSimplifierRewrite) (\ t -> do
binfo <- view ccBuiltinsInfo
pure $ Rewrite.userRewrite binfo t)
, Pass "commuteFnWithConst" (onOption coDoSimplifiercommuteFnWithConst) (pure . CommuteFnWithConst.commuteFnWithConst)
pure $ RewriteRules.userRewrite binfo t)
]

-- | Actual simplifier
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module PlutusIR.Transform.Rewrite where
module PlutusIR.Transform.RewriteRules
( module Export
, userRewrite
, defaultUniRewriteRules
) where

import PlutusIR.Transform.RewriteRules.CommuteFnWithConst as Export
import PlutusIR.Transform.RewriteRules.DecodeEncodeUtf8 as Export

import PlutusCore.Default
import PlutusIR.Analysis.Builtins
import PlutusIR
import Data.Monoid

import Data.Monoid
import Control.Lens

userRewrite :: (Semigroup a, t ~ Term tyname name uni fun a)
Expand All @@ -21,16 +26,8 @@ userRewrite bi t =
defaultUniRewriteRules :: RewriteRules DefaultUni DefaultFun
defaultUniRewriteRules = RewriteRules $ combineRules
-- rules are applied from left to right because of Dual
[ decodeEncode
[ decodeEncodeUtf8
, commuteFnWithConst
]
where
combineRules = foldMap (Dual . Endo)

decodeEncode :: Semigroup a => Term tyname name uni DefaultFun a -> Term tyname name uni DefaultFun a
decodeEncode = \case
BA DecodeUtf8 a1 a2 (BA EncodeUtf8 a3 a4 t) ->
-- place the missed annotations inside the rewritten term
(<> a1 <> a2 <> a3 <> a4) <$> t
t -> t

pattern BA b a1 a2 t <- Apply a1 (Builtin a2 b) t
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}

{- | Commute such that constants are the first arguments. Consider:
Expand Down Expand Up @@ -28,63 +27,55 @@ might expect that `equalsInteger` is the one that will benefit the most.
Plutonomy only commutes `EqualsInteger` in their `commEquals`.
-}

module PlutusIR.Transform.CommuteFnWithConst
(commuteFnWithConst
, commuteDefaultFun)
where
module PlutusIR.Transform.RewriteRules.CommuteFnWithConst
( commuteFnWithConst
) where

import Control.Lens (over)
import Data.Typeable (Typeable, eqT)
import PlutusCore.Default
import PlutusIR.Core.Plated (termSubterms)
import PlutusIR.Core.Type (Term (Apply, Builtin, Constant))

isConstant :: Term tyname name uni fun a -> Bool
isConstant Constant{} = True
isConstant _ = False
isConstant = \case
Constant{} -> True
_ -> False

commuteDefaultFun ::
commuteFnWithConst ::
forall tyname name uni a.
Term tyname name uni DefaultFun a ->
Term tyname name uni DefaultFun a
commuteDefaultFun = over termSubterms commuteDefaultFun . localCommute
where
localCommute tm@(Apply ann (Apply ann1 (Builtin annB fun) x) y@Constant{})
| isCommutative fun && not (isConstant x) =
Apply ann (Apply ann1 (Builtin annB fun) y) x
| otherwise = tm
localCommute tm = tm

commuteFnWithConst :: forall tyname name uni fun a. Typeable fun =>
Term tyname name uni fun a -> Term tyname name uni fun a
commuteFnWithConst = case eqT @fun @DefaultFun of
Just Refl -> commuteDefaultFun
Nothing -> id
commuteFnWithConst = \case
Apply ann1 (Apply ann2 (Builtin ann3 fun) arg1) arg2
| isCommutative fun
, not (isConstant arg1)
, isConstant arg2
-> Apply ann1 (Apply ann2 (Builtin ann3 fun) arg2) arg1
t -> t

-- | Returns whether a `DefaultFun` is commutative. Not using
-- catchall to make sure that this function catches newly added `DefaultFun`.
isCommutative :: DefaultFun -> Bool
isCommutative = \case
AddInteger -> True
SubtractInteger -> False
MultiplyInteger -> True
EqualsInteger -> True
EqualsByteString -> True
EqualsString -> True
EqualsData -> True
-- verbose laid down, to revisit this function if a new builtin is added
SubtractInteger -> False
DivideInteger -> False
QuotientInteger -> False
RemainderInteger -> False
ModInteger -> False
EqualsInteger -> True
LessThanInteger -> False
LessThanEqualsInteger -> False
-- Bytestrings
AppendByteString -> False
ConsByteString -> False
SliceByteString -> False
LengthOfByteString -> False
IndexByteString -> False
EqualsByteString -> True
LessThanByteString -> False
LessThanEqualsByteString -> False
-- Cryptography and hashes
Sha2_256 -> False
Sha3_256 -> False
Blake2b_224 -> False
Expand All @@ -110,27 +101,19 @@ isCommutative = \case
Bls12_381_millerLoop -> False
Bls12_381_mulMlResult -> False
Bls12_381_finalVerify -> False
-- Strings
AppendString -> False
EqualsString -> True
EncodeUtf8 -> False
DecodeUtf8 -> False
-- Bool
IfThenElse -> False
-- Unit
ChooseUnit -> False
-- Tracing
Trace -> False
-- Pairs
FstPair -> False
SndPair -> False
-- Lists
ChooseList -> False
MkCons -> False
HeadList -> False
TailList -> False
NullList -> False
-- Data
ChooseData -> False
ConstrData -> False
MapData -> False
Expand All @@ -142,7 +125,6 @@ isCommutative = \case
UnListData -> False
UnIData -> False
UnBData -> False
EqualsData -> True
SerialiseData -> False
MkPairData -> False
MkNilData -> False
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}

{- | Commute such that constants are the first arguments. Consider:
(1) equalsInteger 1 x
(2) equalsInteger x 1
We have unary application, so these are two partial applications:
(1) (equalsInteger 1) x
(2) (equalsInteger x) 1
With (1), we can share the `equalsInteger 1` node, and it will be the same across any place where
we do this.
With (2), both the nodes here include x, which is a variable that will likely be different in other
invocations of `equalsInteger`. So the second one is harder to share, which is worse for CSE.
So commuting `equalsInteger` so that it has the constant first both a) makes various occurrences of
`equalsInteger` more likely to look similar, and b) gives us a maximally-shareable node for CSE.
This applies to any commutative builtin function that takes constants as arguments, although we
might expect that `equalsInteger` is the one that will benefit the most.
Plutonomy only commutes `EqualsInteger` in their `commEquals`.
-}

module PlutusIR.Transform.RewriteRules.DecodeEncodeUtf8
( decodeEncodeUtf8
) where

import PlutusCore.Default
import PlutusIR


decodeEncodeUtf8 :: Semigroup a => Term tyname name uni DefaultFun a -> Term tyname name uni DefaultFun a
decodeEncodeUtf8 = \case
BA DecodeUtf8 a1 a2 (BA EncodeUtf8 a3 a4 t) ->
-- place the missed annotations inside the rewritten term
(<> a1 <> a2 <> a3 <> a4) <$> t
t -> t

pattern BA b a1 a2 t <- Apply a1 (Builtin a2 b) t
6 changes: 3 additions & 3 deletions plutus-core/plutus-ir/test/PlutusIR/Scoping/Tests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import PlutusIR.Generators.AST
import PlutusIR.Mark
import PlutusIR.Transform.Beta
import PlutusIR.Transform.CaseReduce
import PlutusIR.Transform.CommuteFnWithConst
import PlutusIR.Transform.RewriteRules.CommuteFnWithConst
import PlutusIR.Transform.DeadCode
import PlutusIR.Transform.EvaluateBuiltins
import PlutusIR.Transform.Inline.Inline qualified as Inline
Expand All @@ -32,8 +32,8 @@ test_names = testGroup "names"
pure . beta
, T.test_scopingGood "case-of-known-constructor" genTerm T.BindingRemovalNotOk T.PrerenameYes $
pure . caseReduce
, T.test_scopingGood "'commuteDefaultFun'" genTerm T.BindingRemovalNotOk T.PrerenameYes $
pure . commuteDefaultFun
, T.test_scopingGood "commuteFnWithConst" genTerm T.BindingRemovalNotOk T.PrerenameYes $
pure . commuteFnWithConst
, -- We say that it's fine to remove bindings, because they never actually get removed,
-- because the scope checking machinery doesn't create unused bindings, every binding
-- gets referenced at some point at least once (usually very close to the binding site).
Expand Down

This file was deleted.

22 changes: 0 additions & 22 deletions plutus-core/plutus-ir/test/PlutusIR/Transform/Rewrite/Tests.hs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module PlutusIR.Transform.RewriteRules.Tests where

import PlutusIR.Parser
import PlutusIR.Test
import PlutusIR.Analysis.Builtins
import PlutusIR.Transform.RewriteRules as RewriteRules

import Data.Default.Class
import Control.Lens
import Test.Tasty
import Test.Tasty.Extras

test_RewriteRules :: TestTree
test_RewriteRules = runTestNestedIn ["plutus-ir/test/PlutusIR/Transform"] $
testNested "RewriteRules" $
fmap
(goldenPir (RewriteRules.userRewrite builtinsInfo) pTerm)
[ "decodeEncodeUtf8"
, "equalsInt" -- this tests that the function works on equalInteger
, "divideInt" -- this tests that the function excludes not commutative functions
, "multiplyInt" -- this tests that the function works on multiplyInteger
, "let" -- this tests that it works in the subterms
]
where
builtinsInfo = def & rewriteRules .~ RewriteRules.defaultUniRewriteRules

0 comments on commit 2b41633

Please sign in to comment.