Skip to content

Commit

Permalink
Support type synonyms in instance types (#2772)
Browse files Browse the repository at this point in the history
* Closes #2358
  • Loading branch information
lukaszcz authored May 15, 2024
1 parent 47b3b19 commit 325d43f
Show file tree
Hide file tree
Showing 30 changed files with 498 additions and 335 deletions.
4 changes: 2 additions & 2 deletions app/Commands/Dev/Termination/CallGraph.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import Commands.Base
import Commands.Dev.Termination.CallGraph.Options
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Internal.Pretty qualified as Internal
import Juvix.Compiler.Internal.Translation.FromConcrete.Data.Context qualified as Internal
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination qualified as Termination
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking.Data.Context qualified as Internal
import Juvix.Compiler.Store.Extra qualified as Stored
import Juvix.Prelude.Pretty

runCommand :: (Members '[EmbedIO, TaggedLock, App] r) => CallGraphOptions -> Sem r ()
runCommand CallGraphOptions {..} = do
globalOpts <- askGlobalOptions
PipelineResult {..} <- runPipelineTermination _graphInputFile upToInternal
PipelineResult {..} <- runPipelineTermination _graphInputFile upToInternalTyped
let mainModule = _pipelineResult ^. Internal.resultModule
toAnsiText' :: forall a. (HasAnsiBackend a, HasTextBackend a) => a -> Text
toAnsiText' = toAnsiText (not (globalOpts ^. globalNoColors))
Expand Down
53 changes: 25 additions & 28 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fromInternal i = do
}
res <-
execInfoTableBuilder md
. evalState (i ^. InternalTyped.resultFunctions)
. runReader (i ^. InternalTyped.resultFunctions)
. runReader (i ^. InternalTyped.resultIdenTypes)
$ do
when
Expand Down Expand Up @@ -95,13 +95,13 @@ fromInternalExpression importTab res exp = do
fmap snd
. runReader mtab
. runInfoTableBuilder (res ^. coreResultModule)
. evalState (res ^. coreResultInternalTypedResult . InternalTyped.resultFunctions)
. runReader (res ^. coreResultInternalTypedResult . InternalTyped.resultFunctions)
. runReader (res ^. coreResultInternalTypedResult . InternalTyped.resultIdenTypes)
$ fromTopIndex (goExpression exp)

goModule ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
Internal.Module ->
Sem r ()
goModule m = do
Expand All @@ -110,7 +110,7 @@ goModule m = do
-- | predefine an inductive definition
preInductiveDef ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
Internal.InductiveDef ->
Sem r PreInductiveDef
preInductiveDef i = do
Expand Down Expand Up @@ -149,7 +149,7 @@ preInductiveDef i = do

goInductiveDef ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
PreInductiveDef ->
Sem r ()
goInductiveDef PreInductiveDef {..} = do
Expand All @@ -163,7 +163,7 @@ goInductiveDef PreInductiveDef {..} = do

goConstructor ::
forall r.
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, NameIdGen] r) =>
Symbol ->
Internal.ConstructorDef ->
Sem r ConstructorInfo
Expand Down Expand Up @@ -228,7 +228,7 @@ goConstructor sym ctor = do

goMutualBlock ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
Internal.MutualBlock ->
Sem r ()
goMutualBlock (Internal.MutualBlock m) = preMutual m >>= goMutual
Expand Down Expand Up @@ -267,7 +267,7 @@ goMutualBlock (Internal.MutualBlock m) = preMutual m >>= goMutual

preFunctionDef ::
forall r.
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, NameIdGen] r) =>
Internal.FunctionDef ->
Sem r PreFunctionDef
preFunctionDef f = do
Expand Down Expand Up @@ -336,7 +336,7 @@ preFunctionDef f = do

goFunctionDef ::
forall r.
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, NameIdGen] r) =>
PreFunctionDef ->
Sem r ()
goFunctionDef PreFunctionDef {..} = do
Expand All @@ -355,21 +355,18 @@ goFunctionDef PreFunctionDef {..} = do
let (is, _) = unfoldLambdas node
setIdentArgs _preFunSym (map (^. lambdaLhsBinder) is)

strongNormalizeHelper :: (Members '[State InternalTyped.FunctionsTable, NameIdGen] r) => Internal.Expression -> Sem r Internal.Expression
strongNormalizeHelper ty = evalState InternalTyped.iniState (InternalTyped.strongNormalize' ty)

goType ::
forall r.
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, NameIdGen, Reader IndexTable] r) =>
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, NameIdGen, Reader IndexTable] r) =>
Internal.Expression ->
Sem r Type
goType ty = do
normTy <- strongNormalizeHelper ty
normTy <- InternalTyped.strongNormalize'' ty
squashApps <$> goExpression normTy

mkFunBody ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Type -> -- converted type of the function
Internal.FunctionDef ->
Sem r Node
Expand All @@ -381,7 +378,7 @@ mkFunBody ty f =

mkBody ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Type -> -- type of the function
Location ->
NonEmpty ([Internal.PatternArg], Internal.Expression) ->
Expand Down Expand Up @@ -468,7 +465,7 @@ mkBody ty loc clauses

goCase ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Internal.Case ->
Sem r Node
goCase c = do
Expand Down Expand Up @@ -502,7 +499,7 @@ goCase c = do

goLambda ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Internal.Lambda ->
Sem r Node
goLambda l = do
Expand All @@ -511,7 +508,7 @@ goLambda l = do

goLet ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Internal.Let ->
Sem r Node
goLet l = goClauses (toList (l ^. Internal.letClauses))
Expand Down Expand Up @@ -549,7 +546,7 @@ goLet l = goClauses (toList (l ^. Internal.letClauses))

goAxiomInductive ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
Internal.AxiomDef ->
Sem r ()
goAxiomInductive a = whenJust (a ^. Internal.axiomBuiltin) builtinInductive
Expand Down Expand Up @@ -612,7 +609,7 @@ fromTopIndex = runReader initIndexTable

goAxiomDef ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, NameIdGen] r) =>
Internal.AxiomDef ->
Sem r ()
goAxiomDef a = maybe goAxiomNotBuiltin builtinBody (a ^. Internal.axiomBuiltin)
Expand Down Expand Up @@ -767,7 +764,7 @@ goAxiomDef a = maybe goAxiomNotBuiltin builtinBody (a ^. Internal.axiomBuiltin)

fromPatternArg ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, State IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, State IndexTable, NameIdGen] r) =>
Internal.PatternArg ->
Sem r Pattern
fromPatternArg pa = case pa ^. Internal.patternArgName of
Expand Down Expand Up @@ -851,7 +848,7 @@ fromPatternArg pa = case pa ^. Internal.patternArgName of

goPatternArgs ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Level -> -- the level of the first binder for the matched value
Internal.Expression ->
[Internal.PatternArg] ->
Expand Down Expand Up @@ -909,7 +906,7 @@ addPatternVariableNames p lvl vars =

goIden ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable] r) =>
Internal.Iden ->
Sem r Node
goIden i = do
Expand Down Expand Up @@ -988,7 +985,7 @@ goIden i = do

goExpression ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Internal.Expression ->
Sem r Node
goExpression = \case
Expand All @@ -1008,7 +1005,7 @@ goExpression = \case

goFunction ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
([Internal.FunctionParameter], Internal.Expression) ->
Sem r Node
goFunction (params, returnTypeExpr) = go params
Expand All @@ -1031,7 +1028,7 @@ goFunction (params, returnTypeExpr) = go params

goSimpleLambda ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Internal.SimpleLambda ->
Sem r Node
goSimpleLambda l = do
Expand All @@ -1043,7 +1040,7 @@ goSimpleLambda l = do

goApplication ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen] r) =>
Internal.Application ->
Sem r Node
goApplication a = do
Expand Down
49 changes: 11 additions & 38 deletions src/Juvix/Compiler/Internal/Data/InfoTable.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Internal.Data.InfoTable
( module Juvix.Compiler.Store.Internal.Language,
computeInternalModule,
computeInternalModuleInfoTable,
extendWithReplExpression,
lookupConstructor,
lookupConstructorArgTypes,
Expand All @@ -21,9 +22,9 @@ where

import Data.Generics.Uniplate.Data
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Internal.Data.CoercionInfo
import Juvix.Compiler.Internal.Data.InstanceInfo
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Extra.CoercionInfo
import Juvix.Compiler.Internal.Extra.InstanceInfo
import Juvix.Compiler.Internal.Pretty (ppTrace)
import Juvix.Compiler.Store.Internal.Data.FunctionsTable
import Juvix.Compiler.Store.Internal.Data.TypesTable
Expand Down Expand Up @@ -80,19 +81,21 @@ letFunctionDefs e =
LetFunDef f -> pure f
LetMutualBlock (MutualBlockLet fs) -> fs

computeInternalModule :: TypesTable -> FunctionsTable -> Module -> InternalModule
computeInternalModule tysTab funsTab m@Module {..} =
computeInternalModule :: InstanceTable -> CoercionTable -> TypesTable -> FunctionsTable -> Module -> InternalModule
computeInternalModule instTab coeTab tysTab funsTab m@Module {..} =
InternalModule
{ _internalModuleId = _moduleId,
_internalModuleName = _moduleName,
_internalModuleImports = _moduleBody ^. moduleImports,
_internalModuleInfoTable = computeInfoTable m,
_internalModuleInfoTable = computeInternalModuleInfoTable m,
_internalModuleTypesTable = tysTab,
_internalModuleFunctionsTable = funsTab
_internalModuleFunctionsTable = funsTab,
_internalModuleInstanceTable = instTab,
_internalModuleCoercionTable = coeTab
}

computeInfoTable :: Module -> InfoTable
computeInfoTable m = InfoTable {..}
computeInternalModuleInfoTable :: Module -> InfoTable
computeInternalModuleInfoTable m = InfoTable {..}
where
mutuals :: [MutualStatement]
mutuals =
Expand Down Expand Up @@ -168,36 +171,6 @@ computeInfoTable m = InfoTable {..}
_axiomInfoDef ^. axiomBuiltin
>>= (\b -> Just (BuiltinsAxiom b, _axiomInfoDef ^. axiomName))

_infoInstances :: InstanceTable
_infoInstances = foldr (flip updateInstanceTable) mempty $ mapMaybe mkInstance (HashMap.elems _infoFunctions)
where
mkInstance :: FunctionInfo -> Maybe InstanceInfo
mkInstance (FunctionInfo {..})
| _functionInfoInstance =
instanceFromTypedExpression
( TypedExpression
{ _typedType = _functionInfoType,
_typedExpression = ExpressionIden (IdenFunction _functionInfoName)
}
)
| otherwise =
Nothing

_infoCoercions :: CoercionTable
_infoCoercions = foldr (flip updateCoercionTable) mempty $ mapMaybe mkCoercion (HashMap.elems _infoFunctions)
where
mkCoercion :: FunctionInfo -> Maybe CoercionInfo
mkCoercion (FunctionInfo {..})
| _functionInfoCoercion =
coercionFromTypedExpression
( TypedExpression
{ _typedType = _functionInfoType,
_typedExpression = ExpressionIden (IdenFunction _functionInfoName)
}
)
| otherwise =
Nothing

ss :: [MutualBlock]
ss = m ^. moduleBody . moduleStatements

Expand Down
Original file line number Diff line number Diff line change
@@ -1,50 +1,18 @@
module Juvix.Compiler.Internal.Data.CoercionInfo where
module Juvix.Compiler.Internal.Extra.CoercionInfo
( module Juvix.Compiler.Store.Internal.Data.CoercionInfo,
module Juvix.Compiler.Internal.Extra.CoercionInfo,
)
where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Data.List qualified as List
import Juvix.Compiler.Internal.Data.InstanceInfo
import Juvix.Compiler.Internal.Extra.Base
import Juvix.Compiler.Internal.Extra.InstanceInfo
import Juvix.Compiler.Internal.Language
import Juvix.Extra.Serialize
import Juvix.Compiler.Store.Internal.Data.CoercionInfo
import Juvix.Prelude

data CoercionInfo = CoercionInfo
{ _coercionInfoInductive :: Name,
_coercionInfoParams :: [InstanceParam],
_coercionInfoTarget :: InstanceApp,
_coercionInfoResult :: Expression,
_coercionInfoArgs :: [FunctionParameter]
}
deriving stock (Eq, Generic)

instance Hashable CoercionInfo where
hashWithSalt salt CoercionInfo {..} = hashWithSalt salt _coercionInfoResult

instance Serialize CoercionInfo

-- | Maps trait names to available coercions
newtype CoercionTable = CoercionTable
{ _coercionTableMap :: HashMap InductiveName [CoercionInfo]
}
deriving stock (Eq, Generic)

instance Serialize CoercionTable

makeLenses ''CoercionInfo
makeLenses ''CoercionTable

instance Semigroup CoercionTable where
t1 <> t2 =
CoercionTable $
HashMap.unionWith combine (t1 ^. coercionTableMap) (t2 ^. coercionTableMap)
where
combine :: [CoercionInfo] -> [CoercionInfo] -> [CoercionInfo]
combine ii1 ii2 = nubHashable (ii1 ++ ii2)

instance Monoid CoercionTable where
mempty = CoercionTable mempty

updateCoercionTable :: CoercionTable -> CoercionInfo -> CoercionTable
updateCoercionTable tab ci@CoercionInfo {..} =
over coercionTableMap (HashMap.alter go _coercionInfoInductive) tab
Expand Down
Loading

0 comments on commit 325d43f

Please sign in to comment.