Skip to content

Commit

Permalink
Add deriving for Eq (#3176)
Browse files Browse the repository at this point in the history
This pr adds automatic implementation of `Eq` instances for inductive
types. To create such an instance the user will use the same syntax as a
regular instance with two differences.
1. It is prefixed with the `deriving` keyword.
2. It has no body.

E.g. 
```
deriving instance
eqProductI {A B} {{Eq A}} {{Eq B}} : Eq (Pair A B);
```

This desugars into an instance that returns true when the constructors
match and all arguments are equal according to their respective
instances. There is no special handling of type errors occurring in the
generated code. I.e. if the user forgets a necessary instance argument
in the type signature, a type error will occur in the generated code.

## Stdlib PR
- anoma/juvix-stdlib#148

# Future work
* In the future we should look at
https://www.dreixel.net/research/pdf/gdmh_nocolor.pdf
* See also:
https://gitlab.haskell.org/ghc/ghc/-/wikis/commentary/compiler/generic-deriving

---------

Co-authored-by: Lukasz Czajka <lukasz@heliax.dev>
  • Loading branch information
janmasrovira and lukaszcz authored Nov 22, 2024
1 parent 19ecfa9 commit c100812
Show file tree
Hide file tree
Showing 41 changed files with 905 additions and 281 deletions.
13 changes: 8 additions & 5 deletions src/Juvix/Compiler/Backend/Html/Translation/FromTyped.hs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ goStatement = \case
StatementInductive t -> goInductive t
StatementOpenModule t -> goOpen t
StatementFunctionDef t -> goFunctionDef t
StatementDeriving t -> goDeriving t
StatementSyntax syn -> goSyntax syn
StatementImport t -> goImport t
StatementModule m -> goLocalModule m
Expand Down Expand Up @@ -537,13 +538,15 @@ goAxiom axiom = do
axiomHeader :: Sem r Html
axiomHeader = ppCodeHtml defaultOptions (set axiomDoc Nothing axiom)

goDeriving :: forall r. (Members '[Reader HtmlOptions] r) => Deriving 'Scoped -> Sem r Html
goDeriving def = do
sig <- ppHelper (ppCode def)
defHeader (def ^. derivingFunLhs . funLhsName) sig Nothing

goFunctionDef :: forall r. (Members '[Reader HtmlOptions] r) => FunctionDef 'Scoped -> Sem r Html
goFunctionDef def = do
sig' <- funSig
defHeader (def ^. signName) sig' (def ^. signDoc)
where
funSig :: Sem r Html
funSig = ppHelper (ppCode (functionDefLhs def))
sig <- ppHelper (ppCode (functionDefLhs def))
defHeader (def ^. signName) sig (def ^. signDoc)

goInductive :: forall r. (Members '[Reader HtmlOptions] r) => InductiveDef 'Scoped -> Sem r Html
goInductive def = do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,13 @@ indModuleFilter :: forall s. [Concrete.Statement s] -> [Concrete.Statement s]
indModuleFilter =
filter
( \case
Concrete.StatementSyntax _ -> True
Concrete.StatementFunctionDef _ -> True
Concrete.StatementImport _ -> True
Concrete.StatementInductive _ -> True
Concrete.StatementModule o -> o ^. Concrete.moduleOrigin == LocalModuleSource
Concrete.StatementOpenModule _ -> True
Concrete.StatementAxiom _ -> True
Concrete.StatementProjectionDef _ -> True
Concrete.StatementSyntax {} -> True
Concrete.StatementFunctionDef {} -> True
Concrete.StatementDeriving {} -> True
Concrete.StatementImport {} -> True
Concrete.StatementInductive {} -> True
Concrete.StatementOpenModule {} -> True
Concrete.StatementAxiom {} -> True
Concrete.StatementProjectionDef {} -> True
)
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Builtins.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Builtins
( module Juvix.Compiler.Builtins.Nat,
module Juvix.Compiler.Builtins.IO,
module Juvix.Compiler.Builtins.Eq,
module Juvix.Compiler.Builtins.Int,
module Juvix.Compiler.Builtins.Bool,
module Juvix.Compiler.Builtins.List,
Expand All @@ -24,6 +25,7 @@ import Juvix.Compiler.Builtins.ByteArray
import Juvix.Compiler.Builtins.Cairo
import Juvix.Compiler.Builtins.Control
import Juvix.Compiler.Builtins.Debug
import Juvix.Compiler.Builtins.Eq
import Juvix.Compiler.Builtins.Field
import Juvix.Compiler.Builtins.IO
import Juvix.Compiler.Builtins.Int
Expand Down
25 changes: 25 additions & 0 deletions src/Juvix/Compiler/Builtins/Eq.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module Juvix.Compiler.Builtins.Eq where

import Juvix.Compiler.Internal.Builtins
import Juvix.Compiler.Internal.Extra
import Juvix.Prelude
import Juvix.Prelude.Pretty

checkEqDef :: forall r. (Members '[Reader BuiltinsTable, Error ScoperError] r) => InductiveDef -> Sem r ()
checkEqDef d = do
let err :: forall a. Text -> Sem r a
err = builtinsErrorText (getLoc d)
let eqTxt = prettyText BuiltinEq
unless (isSmallUniverse' (d ^. inductiveType)) (err (eqTxt <> " should be in the small universe"))
case d ^. inductiveParameters of
[_] -> return ()
_ -> err (eqTxt <> " should have exactly one type parameter")
case d ^. inductiveConstructors of
[c1] -> checkMkEq c1
_ -> err (eqTxt <> " should have exactly two constructors")

checkMkEq :: ConstructorDef -> Sem r ()
checkMkEq _ = return ()

checkIsEq :: FunctionDef -> Sem r ()
checkIsEq _ = return ()
13 changes: 13 additions & 0 deletions src/Juvix/Compiler/Concrete/Data/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ builtinConstructors = \case
BuiltinEcPoint -> [BuiltinMkEcPoint]
BuiltinAnomaResource -> [BuiltinMkAnomaResource]
BuiltinAnomaAction -> [BuiltinMkAnomaAction]
BuiltinEq -> [BuiltinMkEq]

data BuiltinInductive
= BuiltinNat
Expand All @@ -67,6 +68,7 @@ data BuiltinInductive
| BuiltinList
| BuiltinMaybe
| BuiltinPair
| BuiltinEq
| BuiltinPoseidonState
| BuiltinEcPoint
| BuiltinAnomaResource
Expand All @@ -87,6 +89,7 @@ instance Pretty BuiltinInductive where
BuiltinList -> Str.list
BuiltinMaybe -> Str.maybe_
BuiltinPair -> Str.pair
BuiltinEq -> Str.eq
BuiltinPoseidonState -> Str.cairoPoseidonState
BuiltinEcPoint -> Str.cairoEcPoint
BuiltinAnomaResource -> Str.anomaResource
Expand All @@ -109,6 +112,7 @@ instance Pretty BuiltinConstructor where
BuiltinMkEcPoint -> Str.cairoMkEcPoint
BuiltinMkAnomaResource -> Str.anomaMkResource
BuiltinMkAnomaAction -> Str.anomaMkAction
BuiltinMkEq -> Str.mkEq

data BuiltinConstructor
= BuiltinNatZero
Expand All @@ -119,6 +123,7 @@ data BuiltinConstructor
| BuiltinIntNegSuc
| BuiltinListNil
| BuiltinListCons
| BuiltinMkEq
| BuiltinMaybeNothing
| BuiltinMaybeJust
| BuiltinPairConstr
Expand Down Expand Up @@ -161,6 +166,7 @@ data BuiltinFunction
| BuiltinIntLe
| BuiltinIntLt
| BuiltinFromNat
| BuiltinIsEqual
| BuiltinFromInt
| BuiltinSeq
| BuiltinMonadBind
Expand Down Expand Up @@ -202,6 +208,7 @@ instance Pretty BuiltinFunction where
BuiltinFromNat -> Str.fromNat
BuiltinFromInt -> Str.fromInt
BuiltinSeq -> Str.builtinSeq
BuiltinIsEqual -> Str.isEqual
BuiltinMonadBind -> Str.builtinMonadBind

data BuiltinAxiom
Expand Down Expand Up @@ -434,6 +441,7 @@ isNatBuiltin = \case
BuiltinNatLt -> True
BuiltinNatEq -> True
--
BuiltinIsEqual -> False
BuiltinAssert -> False
BuiltinBoolIf -> False
BuiltinBoolOr -> False
Expand Down Expand Up @@ -486,13 +494,15 @@ isIntBuiltin = \case
BuiltinFromNat -> False
BuiltinFromInt -> False
BuiltinSeq -> False
BuiltinIsEqual -> False
BuiltinMonadBind -> False

isCastBuiltin :: BuiltinFunction -> Bool
isCastBuiltin = \case
BuiltinFromNat -> True
BuiltinFromInt -> True
--
BuiltinIsEqual -> False
BuiltinAssert -> False
BuiltinIntEq -> False
BuiltinIntPlus -> False
Expand Down Expand Up @@ -532,6 +542,7 @@ isIgnoredBuiltin f
.&&. (not . isIntBuiltin)
.&&. (not . isCastBuiltin)
.&&. (/= BuiltinMonadBind)
.&&. (/= BuiltinIsEqual)
$ f

explicit :: Bool
Expand Down Expand Up @@ -562,6 +573,8 @@ isIgnoredBuiltin f
BuiltinNatLe -> False
BuiltinNatLt -> False
BuiltinNatEq -> False
-- Eq
BuiltinIsEqual -> False
-- Monad
BuiltinMonadBind -> False
-- Ignored
Expand Down
18 changes: 11 additions & 7 deletions src/Juvix/Compiler/Concrete/Data/InfoTableBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data InfoTableBuilder :: Effect where
RegisterRecordInfo :: S.NameId -> RecordInfo -> InfoTableBuilder m ()
RegisterAlias :: S.NameId -> PreSymbolEntry -> InfoTableBuilder m ()
RegisterLocalModule :: ScopedModule -> InfoTableBuilder m ()
GetInfoTable :: InfoTableBuilder m InfoTable
GetBuilderInfoTable :: InfoTableBuilder m InfoTable
GetBuiltinSymbol' :: Interval -> BuiltinPrim -> InfoTableBuilder m S.Symbol
RegisterBuiltin' :: BuiltinPrim -> S.Symbol -> InfoTableBuilder m ()

Expand Down Expand Up @@ -92,7 +92,7 @@ runInfoTableBuilder ini = reinterpret (runState ini) $ \case
modify (over infoScoperAlias (HashMap.insert uid a))
RegisterLocalModule m ->
mapM_ (uncurry registerBuiltinHelper) (m ^. scopedModuleInfoTable . infoBuiltins . to HashMap.toList)
GetInfoTable ->
GetBuilderInfoTable ->
get
GetBuiltinSymbol' i b -> do
tbl <- get @InfoTable
Expand Down Expand Up @@ -154,15 +154,19 @@ anameFromScopedIden s =
}

lookupInfo :: (Members '[InfoTableBuilder, Reader InfoTable] r) => (InfoTable -> Maybe a) -> Sem r a
lookupInfo f = do
tab1 <- ask
fromMaybe (fromJust (f tab1)) . f <$> getInfoTable
lookupInfo f = fromJust <$> lookupInfo' f

lookupInfo' :: (Members '[InfoTableBuilder, Reader InfoTable] r) => (InfoTable -> Maybe a) -> Sem r (Maybe a)
lookupInfo' f = do
tab1 <- getBuilderInfoTable
tab2 <- ask
return (f tab1 <|> f tab2)

lookupFixity :: (Members '[InfoTableBuilder, Reader InfoTable] r) => S.NameId -> Sem r FixityDef
lookupFixity uid = lookupInfo (HashMap.lookup uid . (^. infoFixities))
lookupFixity uid = lookupInfo (^. infoFixities . at uid)

getPrecedenceGraph :: (Members '[InfoTableBuilder, Reader InfoTable] r) => Sem r PrecedenceGraph
getPrecedenceGraph = do
tab <- ask
tab' <- getInfoTable
tab' <- getBuilderInfoTable
return $ combinePrecedenceGraphs (tab ^. infoPrecedenceGraph) (tab' ^. infoPrecedenceGraph)
6 changes: 5 additions & 1 deletion src/Juvix/Compiler/Concrete/Data/NameSignature/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ where

import Juvix.Compiler.Concrete.Data.NameSignature.Error
import Juvix.Compiler.Concrete.Gen qualified as Gen
import Juvix.Compiler.Concrete.Language.Base
import Juvix.Compiler.Concrete.Translation.FromParsed.Analysis.Scoping.Error
import Juvix.Prelude

Expand Down Expand Up @@ -63,8 +64,11 @@ instance (SingI s) => HasNameSignature s (AxiomDef s) where
addArgs :: (Members '[NameSignatureBuilder s] r) => AxiomDef s -> Sem r ()
addArgs a = addArgs (a ^. axiomTypeSig)

instance (SingI s) => HasNameSignature s (FunctionLhs s) where
addArgs FunctionLhs {..} = addArgs _funLhsTypeSig

instance (SingI s) => HasNameSignature s (FunctionDef s) where
addArgs a = addArgs (a ^. signTypeSig)
addArgs = addArgs . functionDefLhs

instance (SingI s) => HasNameSignature s (InductiveDef s, ConstructorDef s) where
addArgs ::
Expand Down
11 changes: 8 additions & 3 deletions src/Juvix/Compiler/Concrete/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module Juvix.Compiler.Concrete.Extra
getPatternAtomIden,
isBodyExpression,
isFunctionLike,
isLhsFunctionLike,
symbolParsed,
)
where
Expand Down Expand Up @@ -46,6 +47,7 @@ groupStatements = \case
-- blank line
g :: Statement s -> Statement s -> Bool
g a b = case (a, b) of
(StatementDeriving _, _) -> False
(StatementSyntax _, StatementSyntax _) -> True
(StatementSyntax (SyntaxFixity _), _) -> False
(StatementSyntax (SyntaxOperator o), s) -> definesSymbol (o ^. opSymbol) s
Expand Down Expand Up @@ -108,6 +110,9 @@ isBodyExpression = \case
SigBodyExpression {} -> True
SigBodyClauses {} -> False

isFunctionLike :: FunctionDef a -> Bool
isFunctionLike = \case
FunctionDef {..} -> not (null (_signTypeSig ^. typeSigArgs)) || not (isBodyExpression _signBody)
isLhsFunctionLike :: FunctionLhs 'Parsed -> Bool
isLhsFunctionLike FunctionLhs {..} = notNull (_funLhsTypeSig ^. typeSigArgs)

isFunctionLike :: FunctionDef 'Parsed -> Bool
isFunctionLike d@FunctionDef {..} =
isLhsFunctionLike (functionDefLhs d) || (not . isBodyExpression) _signBody
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Concrete/Keywords.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import Juvix.Data.Keyword.All
kwCase,
kwCoercion,
kwColon,
kwDeriving,
kwDo,
kwElse,
kwEnd,
Expand Down Expand Up @@ -85,6 +86,7 @@ reservedKeywords :: [Keyword]
reservedKeywords =
[ delimSemicolon,
kwAssign,
kwDeriving,
kwAt,
kwAtQuestion,
kwAxiom,
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Concrete/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ statementLabel = \case
StatementOpenModule {} -> Nothing
StatementProjectionDef {} -> Nothing
StatementFunctionDef f -> Just (f ^. signName . symbolTypeLabel)
StatementDeriving f -> Just (f ^. derivingFunLhs . funLhsName . symbolTypeLabel)
StatementImport i -> Just (i ^. importModulePath . to modulePathTypeLabel)
StatementInductive i -> Just (i ^. inductiveName . symbolTypeLabel)
StatementModule i -> Just (i ^. modulePath . to modulePathTypeLabel)
Expand Down
Loading

0 comments on commit c100812

Please sign in to comment.