Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set a limit on WHNF computation depth #187

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions rzk/src/Language/Rzk/Free/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,19 @@ termIsNF (Free (AnnF info f)) = t'
t' = Free (AnnF info { infoWHNF = Just t', infoNF = Just t' } f)

invalidateWHNF :: TermT var -> TermT var
invalidateWHNF = transFS $ \(AnnF info f) ->
AnnF info { infoWHNF = Nothing, infoNF = Nothing } f
invalidateWHNF tt = case tt of
-- type layer terms that should not be evaluated further
LambdaT{} -> tt
ReflT{} -> tt
TypeFunT{} -> tt
TypeSigmaT{} -> tt
TypeIdT{} -> tt
RecBottomT{} -> tt
TypeUnitT{} -> tt
UnitT{} -> tt

_ -> (`transFS` tt) $ \(AnnF info f) ->
AnnF info { infoWHNF = Nothing, infoNF = Nothing } f

substituteT :: TermT var -> Scope TermT var -> TermT var
substituteT x = substitute x . invalidateWHNF
Expand Down Expand Up @@ -288,7 +299,7 @@ toTerm bvars = go
Rzk.TypeSigma _loc pat tA tB ->
TypeSigma (patternVar pat) (go tA) (toScopePattern pat bvars tB)

Rzk.TypeSigmaTuple _loc (Rzk.SigmaParam _ patA tA) ((Rzk.SigmaParam _ patB tB) : ps) tN ->
Rzk.TypeSigmaTuple _loc (Rzk.SigmaParam _ patA tA) ((Rzk.SigmaParam _ patB tB) : ps) tN ->
go (Rzk.TypeSigmaTuple _loc (Rzk.SigmaParam _loc patX tX) ps tN)
where
patX = Rzk.PatternPair _loc patA patB
Expand All @@ -299,7 +310,7 @@ toTerm bvars = go
Rzk.Lambda _loc (Rzk.ParamPattern _ pat : params) body ->
Lambda (patternVar pat) Nothing (toScopePattern pat bvars (Rzk.Lambda _loc params body))
Rzk.Lambda _loc (Rzk.ParamPatternType _ [] _ty : params) body ->
go (Rzk.Lambda _loc params body)
go (Rzk.Lambda _loc params body)
Rzk.Lambda _loc (Rzk.ParamPatternType _ (pat:pats) ty : params) body ->
Lambda (patternVar pat) (Just (go ty, Nothing))
(toScopePattern pat bvars (Rzk.Lambda _loc (Rzk.ParamPatternType _loc pats ty : params) body))
Expand Down
21 changes: 20 additions & 1 deletion rzk/src/Rzk/TypeCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Language.Rzk.Free.Syntax
import qualified Language.Rzk.Syntax as Rzk

import Debug.Trace
import Text.Read (readMaybe)
import Unsafe.Coerce

-- $setup
Expand Down Expand Up @@ -241,6 +242,10 @@ setOption "render" = \case
"none" -> localRenderBackend Nothing
_ -> const $
issueTypeError $ TypeErrorOther "unknown render backend (use \"svg\", \"latex\", or \"none\")"
setOption "max-whnf-depth" = \case
str | Just n <- readMaybe str, n > 0 -> localMaxWhnfDepth n
_ -> const $
issueTypeError $ TypeErrorOther "invalid number (use any positive integer)"
setOption optionName = const $ const $
issueTypeError $ TypeErrorOther ("unknown option " <> show optionName)

Expand Down Expand Up @@ -645,6 +650,8 @@ data Context var = Context
, verbosity :: Verbosity
, covariance :: Covariance
, renderBackend :: Maybe RenderBackend
, maxWhnfDepth :: !Int
, whnfDepth :: !Int
} deriving (Functor, Foldable)

addVarInCurrentScope :: var -> VarInfo var -> Context var -> Context var
Expand All @@ -668,6 +675,8 @@ emptyContext = Context
, verbosity = Normal
, covariance = Covariant
, renderBackend = Nothing
, maxWhnfDepth = 50
, whnfDepth = 0
}

askCurrentScope :: TypeCheck var (ScopeInfo var)
Expand Down Expand Up @@ -1458,12 +1467,22 @@ tryRestriction = \case
go rs
_ -> pure Nothing

localMaxWhnfDepth :: Int -> TypeCheck var a -> TypeCheck var a
localMaxWhnfDepth n = local $ \ctx -> ctx { maxWhnfDepth = n }

incWhnfDepth :: a -> TypeCheck var a -> TypeCheck var a
incWhnfDepth def tc = do
Context{..} <- ask
if whnfDepth > maxWhnfDepth
then return def
else local (\ctx -> ctx { whnfDepth = whnfDepth + 1}) tc

-- | Compute a typed term to its WHNF.
--
-- >>> unsafeTypeCheck' $ whnfT "(\\ (x : Unit) -> x) unit"
-- unit : Unit
whnfT :: Eq var => TermT var -> TypeCheck var (TermT var)
whnfT tt = performing (ActionWHNF tt) $ case tt of
whnfT tt = incWhnfDepth tt $ case tt of
-- use cached result if it exists
Free (AnnF info _)
| Just tt' <- infoWHNF info -> pure tt'
Expand Down
Loading