Skip to content

Commit

Permalink
added primitive support
Browse files Browse the repository at this point in the history
  • Loading branch information
emekoi committed Jan 10, 2024
1 parent ce6ec10 commit 5fb65f8
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 40 deletions.
2 changes: 1 addition & 1 deletion drafts/pattern-matching/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ instance P.Pretty Type where
pretty = go False
where
go _ (TyMeta i _) = "?" <> P.pretty i
go _ (TyData (TypeId c)) = P.pretty c
go _ (TyData (TypeId c)) = case c.display of TrivialEq c -> P.pretty c
go False (TyFun args ret) =
P.concatWith (\x y -> x P.<+> "->" P.<+> y) (go True <$> args)
P.<+> "->" P.<+> go False ret
Expand Down
65 changes: 58 additions & 7 deletions drafts/pattern-matching/Elab.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ data ElabState = ElabState
, typeInfo :: IORef (IntMap TypeInfo)
, dataCons :: IORef (Map Text DataId)
, dataConInfo :: IORef (IntMap DataInfo)
, primitives :: IORef (Map Text Type)
, globalTerms :: IORef (Map Text Var)
, localTerms :: Map Text Var
, termLevel :: Int
Expand Down Expand Up @@ -98,6 +99,7 @@ runElab fp (Elab s) = do
<*> newIORef mempty
<*> newIORef mempty
<*> newIORef mempty
<*> newIORef mempty
<*> pure mempty
<*> pure 0
<*> pure 0
Expand Down Expand Up @@ -633,8 +635,45 @@ tyInferExpr (P.EString _ s) k = do
TLetV x (VString s) <$> apply k x
tyInferExpr (P.EVar _ v) k = do
lookupVar v >>= apply k
tyInferExpr (P.EPrim {}) _ = error "type error"
tyInferExpr (P.EApp _ (P.EPrim _ _) _) _ = error "TODO: prims"
-- NOTE: i think disallowing bare primitives is fine, since they are not
-- true functions.
tyInferExpr (P.EPrim r p) k = do
ElabState{primitives} <- ask
t <- readIORef primitives >>= Map.lookup (P.getName p) >>> \case
Nothing -> do
throwElab
[ "unknown primitive", errQuote ("#" <> P.getName p) ]
[ (range p, This "used here") ]
Just t -> do
tyZonk t >>= \case
t@(TyFun {}) -> throwElab
["cannot use primitive of type", errPretty t, "as a value"]
[(r, This "use primitive as value")]
t -> pure t
x <- freshLocal "p" t
TLetP x (PrimOp (P.getName p)) [] <$> apply k x
tyInferExpr (P.EApp r (P.EPrim _ p) xs) k = do
ElabState{primitives} <- ask
(args, ret) <- readIORef primitives >>= Map.lookup (P.getName p) >>> \case
Nothing -> do
throwElab
[ "unknown primitive", errQuote ("#" <> P.getName p) ]
[ (range p, This "used here") ]
Just t -> do
tyZonk t >>= \case
TyFun args ret -> pure (args, ret)
t -> throwElab
["cannot apply primitive of type", errPretty t]
[(r, This "attempted application")]
let
arity = length args
arity' = length xs
unless (arity == arity') $ throwElab
[ "primitive", errQuote ("#" <> P.getName p), "has arity", errPretty arity ]
[ (range p, This ["applied to", errPretty arity', "arguments"]) ]
tyCheckExprAll xs args \xs -> do
x <- freshLocal "p" ret
TLetP x (PrimOp (P.getName p)) xs <$> apply k x
tyInferExpr (P.EData r c) k = do
cid <- findDataId c
cinfo <- findDataInfo cid
Expand Down Expand Up @@ -685,20 +724,20 @@ tyInferExpr e@(P.EApp r f xs) k =
kx <- k v
k <- freshLabel "k" [ret]
pure $ TLetK k [v] kx (TApp f k xs)
tyInferExpr (P.ELet _ (P.ExprDecl _ (P.Name _ x) Empty e1) e2) k = do
tyInferExpr (P.ELet _ (P.ExprDecl _ (P.Name _ x) t Empty e1) e2) k = do
-- 1. create the join point j
-- 2. lower e2 to the body of j and create a meta-continuation that
-- creates a letk binding and for j and
-- 3. lower e1 and check that it has the type that e2 expects
-- 4. create the letk binding by applying k
t <- freshMeta
t <- maybe freshMeta tyCheckType t
j <- freshLabel "j" [t]
k <- bindVar t \v -> do
local (\ctx -> ctx
{ localTerms = Map.insert x v ctx.localTerms
}) $ TLetK j [v] <$> tyInferExpr e2 k
k <$> tyCheckExpr e1 t (Abstract j)
tyInferExpr (P.ELet _ (P.ExprDecl _ _f _ps _e1) _e2) _k = do
tyInferExpr (P.ELet _ (P.ExprDecl _ _f _t _ps _e1) _e2) _k = do
error "TODO"
tyInferExpr (P.EMatch _ e alts) k = do
tyInferExpr e $ Wrap \x -> do
Expand Down Expand Up @@ -729,25 +768,37 @@ tyInferExpr (P.EMatch _ e alts) k = do
} :<| rows)

tyInferExprDecl :: HasCallStack => P.ExprDecl Range -> Elab Decl
tyInferExprDecl (P.ExprDecl _r (P.Name _ f) ps e) = do
tyInferExprDecl (P.ExprDecl _r (P.Name _ f) t ps e) = do
ts <- traverse (const freshMeta) ps
bindPatterns ps ts \_ps _vs -> do
liftIO $ putStr "\t" >> print _ps
liftIO $ putStr "\t" >> print _vs
t <- freshMeta
t <- maybe freshMeta tyCheckType t
k <- freshLabel "k" [t]
f <- freshName (\x i -> Label (Global x i) ts) f
-- TODO: how do i know if this is a value? probably some check of triviality
DTerm f k <$> tyCheckExpr e t (Abstract k)

elaborate :: HasCallStack => P.Module -> Elab ()
elaborate (P.Module ds) = do
-- get the names of all data types
forM_ ds \case
P.DData _ m c _ -> do
makeTypeId m c $> ()
_ -> pure ()
ElabState{primitives} <- ask
exprs <- (`witherM` ds) \case
P.DData _ m c xs -> tyCheckDataDecl m c xs $> Nothing
P.DPrim _ (P.Name r n) t -> do
readIORef primitives >>= Map.lookup n >>> \case
Just _ -> do
throwElab
[ "redefinition of primitive", errQuote ("#" <> n) ]
[ (r, This "redefined here") ]
Nothing -> do
t <- tyCheckType t
modifyIORef primitives (Map.insert n t)
pure Nothing
P.DExpr e -> pure (Just e)

-- mapM_ tyCheckExprDecl exprs
Expand Down
13 changes: 6 additions & 7 deletions drafts/pattern-matching/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Error
( module Error
, IsString (..)
, Marker (..)
, Position (..)
, Position (Position)
, Report (..)
) where

Expand Down Expand Up @@ -36,10 +36,9 @@ instance IsList ErrMsg where
toList e = [e]

instance P.Pretty ErrMsg where
pretty (ErrText txt) = P.pretty txt
pretty (ErrList xs) = P.hsep $
P.pretty <$> xs
pretty (ErrShow s) = P.viaShow s
pretty (ErrText txt) = P.pretty txt
pretty (ErrList xs) = P.hsep $ P.pretty <$> xs
pretty (ErrShow s) = P.viaShow s
pretty (ErrPretty f p) = f $ P.pretty p

newtype Error
Expand All @@ -54,7 +53,7 @@ pattern Err' :: msg -> [(Position, Marker msg)] -> Report msg
pattern Err' msg xs = Err Nothing msg xs []

throwError :: ErrMsg -> [(Position, Marker ErrMsg)] -> error
throwError msg = throw . Error . (:[]) . Err' msg
throwError msg = throw . Error . (: []) . Err' msg

throwError' :: (MonadIO m) => ErrMsg -> [(Position, Marker ErrMsg)] -> m error
throwError' msg = liftIO . throwIO . Error . (:[]) . Err' msg
throwError' msg = liftIO . throwIO . Error . (: []) . Err' msg
22 changes: 15 additions & 7 deletions drafts/pattern-matching/Lexer.x
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,16 @@ tokens :-

-- literals
<0> "-"? $digit+ { tokInteger }
<0> \"[^\"]*\" { tokByteString String }
<0> \" { begin string }
<string> \" { begin 0 }
<string> [^\"]* { tokByteString String }

-- keywords
<0> data { tok Data }
<0> in { tok In }
<0> let { tok Let }
<0> match { tok Match }
<0> data { tok Data }
<0> in { tok In }
<0> let { tok Let }
<0> match { tok Match }
<0> primitive { tok Primitive }

-- operators and symbols
<0> "{" { tok LBracket }
Expand All @@ -65,13 +68,16 @@ tokens :-
<0> "_" { tok Underscore }
<0> "=" { tok Eq }
<0> ":" { tok Colon }
<0> "#" { tok Hash }

-- variables and constructors
<0> @variable { tokByteString Variable }
<0> @constructor { tokByteString Constructor }

{ {-# LINE 73 "Lexer.x" #-}
-- hash and magic
<0> "#" { begin magic }
<magic> [^$white]* { tokByteString (\x -> if BS.null x then Hash else Magic x) `andBegin` 0 }

{ {-# LINE 81 "Lexer.x" #-}
-- -----------------------------------------------------------------------------
-- The input type
type AlexInput = (AlexPosn, -- current position,
Expand Down Expand Up @@ -287,6 +293,7 @@ data TokenClass
= EOF
| Constructor ByteString
| Variable ByteString
| Magic ByteString
-- literals
| String ByteString
| Int Integer
Expand All @@ -295,6 +302,7 @@ data TokenClass
| In
| Let
| Match
| Primitive
-- operators and symbols
| LBracket
| RBracket
Expand Down
4 changes: 2 additions & 2 deletions drafts/pattern-matching/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import System.IO qualified as System
main :: IO ()
main = do
(input, file) <- System.getArgs >>= \case
[] -> (, "<stdin>") <$> BS.getContents
file : _ -> (, file) <$> BS.readFile file
[] -> (,"<stdin>") <$> BS.getContents
file : _ -> (,file) <$> BS.readFile file
let diagFile = addFile mempty file (BS.unpack input)
handle (handleErr diagFile) do
let ds = runAlex' file input parse
Expand Down
48 changes: 32 additions & 16 deletions drafts/pattern-matching/Parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ import Prettyprinter qualified as P
%token
variableT { L.Token _ (L.Variable _) }
constructorT { L.Token _ (L.Constructor _) }
magicT { L.Token _ (L.Magic _) }
intT { L.Token _ (L.Int _) }
stringT { L.Token _ (L.String _) }
dataT { L.Token _ L.Data }
inT { L.Token _ L.In }
letT { L.Token _ L.Let }
matchT { L.Token _ L.Match }
primitiveT { L.Token _ L.Primitive }
'{' { L.Token _ L.LBracket }
'}' { L.Token _ L.RBracket }
'(' { L.Token _ L.LParen }
Expand Down Expand Up @@ -103,23 +105,33 @@ variable :: { Name L.Range }
constructor :: { Name L.Range }
: constructorT { unToken $1 \r (L.Constructor name) -> Name r (decodeUtf8 name) }

magic :: { Name L.Range }
: magicT { unToken $1 \r (L.Magic name) -> Name r (decodeUtf8 name) }

tyAtom :: { Type L.Range }
: '(' type ')' { $2 }
| constructor { TyData (range $1) $1 }

type :: { Type L.Range }
: tyAtom { $1 }
| sepBy1(tyAtom, '->') tyAtom { TyFun (rangeSeq' $1 $2) $1 $2 }
: tyAtom { $1 }
| type '->' tyAtom {
case $1 of
TyFun r args ret -> TyFun (r <-> $3) (args :|> ret) $3
_ -> TyFun ($1 <-> $3) (Seq.singleton $1) $3
}

annot :: { Type L.Range }
: ':' type { $2 }

pattern_1 :: { Pattern L.Range }
: '(' pattern_3 ')' { $2 }
| '(' pattern_3 ':' type ')' { PType ($1 <-> $5) $2 $4 }
| '_' { unToken $1 \r _ -> PWild r }
| intT { unToken $1 \r (L.Int i) -> PInt r i }
| stringT { unToken $1 \r (L.String s) -> PString r (decodeUtf8 s) }
| variable { let (Name r _) = $1 in PAs r $1 (PWild r) }
| constructor { PData (range $1) $1 Empty }
| variable '@' pattern_1 { PAs ($1 <-> $3) $1 $3 }
: '(' pattern_3 ')' { $2 }
| '(' pattern_3 annot ')' { PType ($1 <-> $4) $2 $3 }
| '_' { unToken $1 \r _ -> PWild r }
| intT { unToken $1 \r (L.Int i) -> PInt r i }
| stringT { unToken $1 \r (L.String s) -> PString r (decodeUtf8 s) }
| variable { let (Name r _) = $1 in PAs r $1 (PWild r) }
| constructor { PData (range $1) $1 Empty }
| variable '@' pattern_1 { PAs ($1 <-> $3) $1 $3 }

pattern_2 :: { Pattern L.Range }
: pattern_1 { $1 }
Expand All @@ -139,7 +151,7 @@ alt :: { Alt L.Range }

atom :: { Expr L.Range }
: variable { EVar (range $1) $1 }
| variable '#' { EPrim ($1 <-> $2) $1 }
| magic { EPrim (range $1) $1 }
| intT { unToken $1 \r (L.Int i) -> EInt r i }
| stringT { unToken $1 \r (L.String s) -> EString r (decodeUtf8 s) }
| constructor { EData (range $1) $1 }
Expand All @@ -157,17 +169,18 @@ dataCon :: { DataCon L.Range }
: constructor many(tyAtom) { DataCon (rangeSeq $1 $2) $1 $2 }
edecl :: { ExprDecl L.Range }
: variable many(pattern_1) '=' expr { ExprDecl ($1 <-> $4) $1 $2 $4 }
: variable many(pattern_1) optional(annot) '=' expr { ExprDecl ($1 <-> $5) $1 $3 $2 $5 }
decl :: { Decl L.Range }
: letT edecl { DExpr $2 }
| primitiveT magic ':' type { DPrim ($1 <-> $4) $2 $4 }
| dataT optionalB('#') constructor { DData ($1 <-> $3) $2 $3 Empty }
| dataT optionalB('#') constructor '=' sepBy(dataCon, '|') { DData (rangeSeq $1 $5) $2 $3 $5 }
decls :: { Module }
: many(decl) { Module $1 }
{ {-# LINE 171 "Parser.y" #-}
{ {-# LINE 184 "Parser.y" #-}
decodeUtf8 :: ByteString -> Text
decodeUtf8 = LT.toStrict . LE.decodeUtf8
Expand Down Expand Up @@ -293,15 +306,16 @@ instance P.Pretty (Alt a) where
<+> P.pretty e

data ExprDecl a
= ExprDecl a (Name a) (Seq (Pattern a)) (Expr a)
= ExprDecl a (Name a) (Maybe (Type a)) (Seq (Pattern a)) (Expr a)
deriving (Foldable, Show, Functor)

instance HasInfo (ExprDecl i) i

instance P.Pretty (ExprDecl a) where
pretty (ExprDecl _ n xs e) =
"let" <+> P.concatWith (<+>) (P.pretty n :<| (wrap <$> xs)) <+> "=" <+> P.pretty e
pretty (ExprDecl _ n t xs e) =
"let" <+> P.concatWith (<+>) (P.pretty n :<| (wrap <$> xs)) <+> eqT <+> P.pretty e
where
eqT = case t of Nothing -> "="; Just t -> ":" <+> P.pretty t <+> "="
wrap p@(PData _ _ ps) | not (null ps) = P.parens $ P.pretty p
wrap p@(POr {}) = P.parens $ P.pretty p
wrap p = P.pretty p
Expand Down Expand Up @@ -348,13 +362,15 @@ instance P.Pretty (DataCon a) where
data Decl a
= DExpr (ExprDecl a)
| DPrim a (Name a) (Type a)
| DData a Bool (Name a) (Seq (DataCon a))
deriving (Foldable, Show, Functor)
instance HasInfo (Decl i) i
instance P.Pretty (Decl a) where
pretty (DExpr e) = P.pretty e
pretty (DPrim _ p t) = "primitive" <+> "#" <> P.pretty p <+> ":" <+> P.pretty t
pretty (DData _ m c (fmap P.pretty -> xs)) =
if null xs then d <+> P.pretty c else
d <+> P.pretty c <+> "="
Expand Down

0 comments on commit 5fb65f8

Please sign in to comment.