Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate arbitrary lambdas in animation primitive #1169

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion primer/src/Primer/Eval/Prim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ data ApplyPrimFunDetail = ApplyPrimFunDetail

-- | If this node is a reducible application of a primitive, return the name of the primitive, the arguments, and
-- (a computation for building) the result.
tryPrimFun :: Map GVarName PrimDef -> Expr -> Maybe (GVarName, [Expr], forall m. MonadFresh ID m => m Expr)
tryPrimFun :: Map GVarName PrimDef -> Expr -> Maybe (GVarName, [Expr], forall m. MonadFresh ID m => (Expr -> m Expr) -> m Expr)
tryPrimFun primDefs expr
| -- Since no primitive functions are polymorphic, there is no need to unfoldAPP
(Var _ (GlobalVarRef name), args) <- bimap stripAnns (map stripAnns) $ unfoldApp expr
Expand Down
27 changes: 21 additions & 6 deletions primer/src/Primer/Eval/Redex.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE OverloadedRecordDot #-}
Expand Down Expand Up @@ -27,9 +28,9 @@ import Foreword

import Control.Monad.Fresh (MonadFresh)
import Control.Monad.Log (MonadLog, WithSeverity)
import Control.Monad.Trans.Maybe (MaybeT)
import Control.Monad.Trans.Maybe (MaybeT, runMaybeT)
import Data.Data (Data)
import Data.Generics.Uniplate.Data (children, descendM)
import Data.Generics.Uniplate.Data (children, descendM, transformM)
import Data.List (zip3)
import Data.Map qualified as M
import Data.Set qualified as S
Expand Down Expand Up @@ -433,14 +434,16 @@ data Redex
-- ^ The original redex (used for details)
}
| ApplyPrimFun
{ result :: forall m. MonadFresh ID m => m Expr
{ result :: forall m. MonadFresh ID m => (Expr -> m Expr) -> m Expr
-- ^ The result of the applied primitive function
, primFun :: GVarName
-- ^ The applied primitive function (used for details)
, args :: [Expr]
-- ^ The original arguments to @primFun@ (used for details)
, orig :: Expr
-- ^ The original redex (used for details)
, tydefs :: TypeDefMap
, globals :: DefMap
}

data RedexType
Expand Down Expand Up @@ -779,7 +782,7 @@ viewRedex opts tydefs globals dir = \case
$ hoistMaybe
$ tryPrimFun (M.mapMaybe defPrim globals) e
>>= \(primFun, args, result) ->
pure ApplyPrimFun{result, primFun, args, orig = e}
pure ApplyPrimFun{result, primFun, args, orig = e, tydefs, globals}
-- (Λa.t : ∀b.T) S ~> (letType a = S in t) : (letType b = S in T)
orig@(APP _ (Ann _ (LAM m a body) (TForall _ forallVar forallKind tgtTy)) argTy) ->
pure
Expand Down Expand Up @@ -1216,8 +1219,20 @@ runRedex opts = \case
-- We should replace this with a proper exception. See:
-- https://github.com/hackworthltd/primer/issues/148
| otherwise -> error "Internal Error: RenameBindingsCase found no applicable branches"
ApplyPrimFun{result, primFun, orig, args} -> do
expr' <- result
ApplyPrimFun{result, primFun, orig, args, tydefs, globals} -> do
-- TODO this can run forever - we haven't set a bound on number of steps
-- TODO `transformM` probably doesn't give us the right eval order - reuse existing machinery
expr' <- result $ fix $ \f -> transformM \e ->
maybe (pure e) (f . fst <=< runRedex opts)
=<< runMaybeT
( flip runReaderT mempty
$ viewRedex
(ViewRedexOptions True True False) -- TODO ?
tydefs
globals
Syn -- TODO ?
e
)
let details =
ApplyPrimFunDetail
{ before = orig
Expand Down
109 changes: 56 additions & 53 deletions primer/src/Primer/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.ByteString.Base64 qualified as B64
import Data.Data (Data)
import Data.Map qualified as M
import Data.Set (isSubsetOf)
import Data.Set qualified as Set
import Diagrams.Backend.Rasterific (
Options (RasterificOptions),
Rasterific (Rasterific),
)
import Diagrams.Prelude (
Diagram,
V2 (..),
blue,
circle,
deg,
fillColor,
Expand All @@ -58,6 +61,7 @@ import Diagrams.Prelude (
renderDia,
rotate,
sRGB24,
text,
translate,
(@@),
)
Expand All @@ -79,9 +83,9 @@ import Primer.Core (
GVarName,
GlobalName,
ID,
LocalName (unLocalName),
ModuleName,
PrimCon (PrimAnimation, PrimChar, PrimInt),
TmVarRef (LocalVarRef),
TyConName,
Type' (..),
ValConName,
Expand All @@ -95,7 +99,7 @@ import Primer.Core.DSL (
prim,
tcon,
)
import Primer.Core.Utils (generateIDs)
import Primer.Core.Utils (freeVars, generateIDs)
import Primer.JSON (CustomJSON (..), PrimerJSON)
import Primer.Name (Name)
import Primer.Primitives.PrimDef (PrimDef (..))
Expand Down Expand Up @@ -242,48 +246,49 @@ primFunTypes = \case
a = TApp ()
f = TFun ()

primFunDef :: PrimDef -> [Expr' () () ()] -> Either PrimFunError (forall m. MonadFresh ID m => m Expr)
primFunDef :: PrimDef -> [Expr' () () ()] -> Either PrimFunError (forall m. MonadFresh ID m => (Expr -> m Expr) -> m Expr)
primFunDef def args = case def of
ToUpper -> case args of
[PrimCon _ (PrimChar c)] ->
Right $ char $ toUpper c
Right $ const $ char $ toUpper c
_ -> err
IsSpace -> case args of
[PrimCon _ (PrimChar c)] ->
Right $ boolAnn (isSpace c)
Right $ const $ boolAnn (isSpace c)
_ -> err
HexToNat -> case args of
[PrimCon _ (PrimChar c)] -> Right $ maybeAnn (tcon tNat) nat (digitToIntSafe c)
[PrimCon _ (PrimChar c)] -> Right $ const $ maybeAnn (tcon tNat) nat (digitToIntSafe c)
where
digitToIntSafe :: Char -> Maybe Natural
digitToIntSafe c' = fromIntegral <$> (guard (isHexDigit c') $> digitToInt c')
_ -> err
NatToHex -> case args of
[exprToNat -> Just n] ->
Right $ maybeAnn (tcon tChar) char $ intToDigitSafe n
Right $ const $ maybeAnn (tcon tChar) char $ intToDigitSafe n
where
intToDigitSafe :: Natural -> Maybe Char
intToDigitSafe n' = guard (0 <= n && n <= 15) $> intToDigit (fromIntegral n')
_ -> err
EqChar -> case args of
[PrimCon _ (PrimChar c1), PrimCon _ (PrimChar c2)] ->
Right $ boolAnn $ c1 == c2
Right $ const $ boolAnn $ c1 == c2
_ -> err
IntAdd -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ int $ x + y
Right $ const $ int $ x + y
_ -> err
IntMinus -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ int $ x - y
Right $ const $ int $ x - y
_ -> err
IntMul -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ int $ x * y
Right $ const $ int $ x * y
_ -> err
IntQuotient -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right
$ const
$ maybeAnn (tcon tInt) int
$ if y == 0
then Nothing
Expand All @@ -292,6 +297,7 @@ primFunDef def args = case def of
IntRemainder -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right
$ const
$ maybeAnn (tcon tInt) int
$ if y == 0
then Nothing
Expand All @@ -300,91 +306,88 @@ primFunDef def args = case def of
IntQuot -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right
$ const
$ int
$ if y == 0 then 0 else x `div` y
_ -> err
IntRem -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right
$ const
$ int
$ if y == 0
then x
else x `mod` y
_ -> err
IntLT -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ boolAnn $ x < y
Right $ const $ boolAnn $ x < y
_ -> err
IntLTE -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ boolAnn $ x <= y
Right $ const $ boolAnn $ x <= y
_ -> err
IntGT -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ boolAnn $ x > y
Right $ const $ boolAnn $ x > y
_ -> err
IntGTE -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ boolAnn $ x >= y
Right $ const $ boolAnn $ x >= y
_ -> err
IntEq -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ boolAnn $ x == y
Right $ const $ boolAnn $ x == y
_ -> err
IntNeq -> case args of
[PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] ->
Right $ boolAnn $ x /= y
Right $ const $ boolAnn $ x /= y
_ -> err
IntToNat -> case args of
[PrimCon _ (PrimInt x)] ->
Right
$ const
$ maybeAnn (tcon tNat) nat
$ if x < 0
then Nothing
else Just $ fromInteger x
_ -> err
IntFromNat -> case args of
[exprToNat -> Just n] ->
Right $ int $ fromIntegral n
Right $ const $ int $ fromIntegral n
_ -> err
Animate -> case args of
-- Since we only support translating a `Picture` expression to an image once it is in normal form,
-- this guard will only pass when `picture` has no free variables other than `time`.
[PrimCon () (PrimInt duration), Lam () time picture]
| Just (frames :: [Diagram Rasterific]) <- traverse diagramAtTime [0 .. (duration * 100) `div` frameLength - 1] ->
Right
$ prim
$ PrimAnimation
$ either
-- This case really shouldn't be able to happen, unless `diagrams-rasterific` is broken.
-- In fact, the default behaviour (`animatedGif`) is just to write the error to `stdout`,
-- and we only have to handle this because we need to use the lower-level `rasterGif`,
-- for unrelated reasons (getting the `Bytestring` without dumping it to a file).
mempty
(decodeUtf8 . B64.encode . toS)
$ encodeComplexGifImage
$ GifEncode (fromInteger width) (fromInteger height) Nothing Nothing gifLooping
$ flip palettizeWithAlpha DisposalRestoreBackground
$ map
( (fromInteger frameLength,)
. renderDia
Rasterific
(RasterificOptions (mkSizeSpec $ Just . fromInteger <$> V2 width height))
. rectEnvelope
(fromInteger <$> mkP2 (-width `div` 2) (-height `div` 2))
(fromInteger <$> V2 width height)
)
frames
[PrimCon () (PrimInt duration), Lam () time picture] | freeVars picture `isSubsetOf` Set.singleton (unLocalName time) -> Right \eval -> do
frames0 <- for [0 .. (duration * 100) `div` frameLength - 1] \t ->
-- TODO let the evaluator do the beta reduction as well?
fmap exprToDiagram . eval =<< generateIDs (Let () time (PrimCon () (PrimInt t)) picture)
-- TODO better error handling
let (frames :: [Diagram Rasterific]) = fromMaybe [text "error" <> (circle 40 & fillColor blue)] $ sequence frames0
prim
$ PrimAnimation
$ either
-- This case really shouldn't be able to happen, unless `diagrams-rasterific` is broken.
-- In fact, the default behaviour (`animatedGif`) is just to write the error to `stdout`,
-- and we only have to handle this because we need to use the lower-level `rasterGif`,
-- for unrelated reasons (getting the `Bytestring` without dumping it to a file).
mempty
(decodeUtf8 . B64.encode . toS)
$ encodeComplexGifImage
$ GifEncode (fromInteger width) (fromInteger height) Nothing Nothing gifLooping
$ flip palettizeWithAlpha DisposalRestoreBackground
$ map
( (fromInteger frameLength,)
. renderDia
Rasterific
(RasterificOptions (mkSizeSpec $ Just . fromInteger <$> V2 width height))
. rectEnvelope
(fromInteger <$> mkP2 (-width `div` 2) (-height `div` 2))
(fromInteger <$> V2 width height)
)
frames
where
-- Note that this simple substitution hack only allows for trivial functions,
-- i.e. those where only substitution is needed for the function body to reach a normal form.
-- Our primitives system doesn't yet support further evaluation here.
diagramAtTime t = exprToDiagram $ substTime (PrimCon () (PrimInt t)) picture
where
substTime a = \case
Var () (LocalVarRef t') | t' == time -> a
Con () c es -> Con () c $ map (substTime a) es
e -> e
-- Values which are hardcoded, for now at least, for the sake of keeping the student-facing API simple.
-- We keep the frame rate and resolution low to avoid serialising huge GIFs.
gifLooping = LoopingForever
Expand All @@ -394,7 +397,7 @@ primFunDef def args = case def of
_ -> err
PrimConst -> case args of
[x, _] ->
Right $ generateIDs x `ann` tcon tBool
Right $ const $ generateIDs x `ann` tcon tBool
_ -> err
where
exprToNat = \case
Expand Down
30 changes: 18 additions & 12 deletions primer/src/Primer/Typecheck.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedLabels #-}

-- | Typechecking for Core expressions.
Expand Down Expand Up @@ -744,20 +745,25 @@ check t = \case
scrutWrap <- Hole <$> meta' (TCSynthed (TEmptyHole ())) <*> pure (addChkMetaT (TEmptyHole ()) e')
pure $ Case caseMeta scrutWrap [] CaseExhaustive
Left (TDIPrim tc) -> do
unless (tc == tInt || tc == tChar || tc == tAnimation) $ throwError' $ InternalError $ "Unknown primitive type: " <> show tc
let f b = case caseBranchName b of
PatCon _ -> Nothing
PatPrim pc -> case pc of
PrimInt p | tc == tInt -> Just $ Left (p, b)
PrimChar p | tc == tChar -> Just $ Right (p, b)
_ -> Nothing
-- all branches right sort & order
sh <- asks smartHoles
brs' <- case partitionEithers <$> traverse f brs of
Just ([], chs) | isSorted (fst <$> chs) -> pure $ snd <$> chs
Just (is, []) | isSorted (fst <$> is) -> pure $ snd <$> is
_ | NoSmartHoles <- sh -> throwError' $ WrongCaseBranches tc (caseBranchName <$> brs) (fb /= CaseExhaustive)
_ | SmartHoles <- sh -> pure []
consistentBranches <-
if
| tc == tInt -> pure $ maybe False isSorted $ for (map caseBranchName brs) $ \case
PatPrim (PrimInt p) -> pure p
_ -> Nothing
| tc == tChar -> pure $ maybe False isSorted $ for (map caseBranchName brs) $ \case
PatPrim (PrimChar p) -> pure p
_ -> Nothing
-- some primitives do not admit any sensible notion of pattern matching
| tc == tAnimation -> pure $ null brs
| otherwise -> throwError' $ InternalError $ "Unknown primitive type: " <> show tc
brs' <-
if consistentBranches
then pure brs
else case sh of
NoSmartHoles -> throwError' $ WrongCaseBranches tc (caseBranchName <$> brs) (fb /= CaseExhaustive)
SmartHoles -> pure []
-- no params, check the rhs
brs'' <- for brs' $ \(CaseBranch c ps rhs) -> do
case (ps, sh) of
Expand Down
2 changes: 1 addition & 1 deletion primer/test/Tests/EvalFull.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1972,7 +1972,7 @@ test_animation =
cTranslate
[ int 35
, int 0
, con1 cCircle $ lvar "t"
, con1 cCircle $ pfun IntMul `app` lvar "t" `app` int 2
]
]
]
Expand Down
Binary file modified primer/test/outputs/eval/animation/2.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.