Skip to content
This repository has been archived by the owner on Oct 18, 2021. It is now read-only.

Commit

Permalink
Quickly look at applications
Browse files Browse the repository at this point in the history
  • Loading branch information
Matheus Magalhães de Alcantara committed Oct 19, 2019
1 parent 7304991 commit 80461c9
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 84 deletions.
151 changes: 105 additions & 46 deletions src/Types/Infer/App.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
{-# LANGUAGE FlexibleContexts, TupleSections, ScopedTypeVariables,
TypeFamilies, CPP, StandaloneDeriving, UndecidableInstances #-}

-- | This module implements the "Quick Look" impredicative polymorphism.
-- At a glance: Application spines are traversed twice, the first time
-- to collect impredicative instantiations (this is the "quick look"
-- pass, which is so called because it ignores hard expressions - @let@,
-- @match@ etc) and the second time to type-check as usual.
module Types.Infer.App (inferApps) where

import Prelude
Expand All @@ -16,12 +22,14 @@ import Data.Maybe
import Control.Monad.Infer
import Control.Lens

import Syntax.Builtin
import Syntax.Subst
import Syntax.Types
import Syntax.Var
import Syntax

import Types.Infer.Builtin
import {-# SOURCE #-} Types.Infer
import Types.Infer.Builtin
import Types.Kinds
import Types.Unify

Expand Down Expand Up @@ -54,6 +62,7 @@ inferApps exp expected =
traceM ("function type: " ++ displayS (pretty function_t))
#endif

-- Pass 1:
((ql_sub, quantifiers), cs) <-
censor (const mempty) . listen $
quickLook function_t =<< for arguments (\a@(x, y) -> (x, y,) <$> inferQL a)
Expand All @@ -69,11 +78,16 @@ inferApps exp expected =

tell cs

ty_syms <- view tySyms

r_ql_sub <- case expected of
Just tau@(TyApps (TyCon _) (_:_)) | result@(TyApps (TyCon _) (_:_)) <- getQuantR quantifiers -> do
-- Pushing down the result type into the quick look substitution
-- is only sound if both are /guarded/. Do note that missing a
-- substitution here isn't unsound, it'll just lead to wonky inference.
Just tau@(TyApps (TyCon t) (_:_))
| result@(TyApps (TyCon t') (_:_)) <- getQuantR quantifiers
, invariant ty_syms t, invariant ty_syms t' -> do
-- Pushing down the result type into the quick look substitution
-- is only sound if both are /guarded/, i.e. applications of an
-- (invariant) constructor. Do note that missing a substitution
-- here isn't unsound, it'll just lead to wonky inference.
_ <- subsumes (becauseExp exp) result (apply ql_sub tau)
pure $ fromMaybe mempty (unifyPure result tau)
_ -> pure mempty
Expand All @@ -99,48 +113,60 @@ inferApps exp expected =
Nothing -> pure id

pure (wrap (foldr (.) id (reverse arg_ks) function), result)
where
spine ex@(App fn arg _) =
let sp = spine fn
in (ExprArg arg, BecauseOf ex):sp
spine ex@(Vta fn arg _) =
let sp = spine fn
in (TypeArg arg, BecauseOf ex):sp
spine ex = [(ExprArg ex, BecauseOf ex)]

checkArguments ((ExprArg arg, _):as) (Quant tau dom cod inst_cont qs) =
case dom of
Anon dom -> do
x <- check arg dom

let cont ex = App ex x (annotation ex <> annotation x, cod)

(conts, result) <- checkArguments as qs
pure (inst_cont:cont:conts, result)
Invisible _ _ Req -> do
(_, t) <- infer arg
b <- freshTV
confesses (NotEqual tau (TyArr t b))
_ -> error "checkArguments ExprArg: impossible quantifier"

checkArguments ((TypeArg arg, reason):as) (Quant tau dom cod inst_cont qs) =
case dom of
Invisible v kind r | r /= Infer{} -> do
arg <- case kind of
Just k -> checkAgainstKind reason arg k
Nothing -> resolveKind reason arg

let ty = apply (Map.singleton v arg) cod
cont ex = ExprWrapper (TypeApp arg) ex (annotation ex <> annotation reason, ty)

(conts, result) <- checkArguments as qs
pure (inst_cont:cont:conts, result)
_ -> confesses (ArisingFrom (CanNotVta tau arg) reason)

checkArguments [] Quant{} = error "arity mismatch. impossible in checkArguments"

checkArguments _ (Result tau) = pure ([], tau)

spine :: (Ann p ~ Ann Resolved, Var p ~ Var Resolved) => Expr p -> [(Arg p, SomeReason)]
spine ex@(App fn arg _) =
let sp = spine fn
in (ExprArg arg, BecauseOf ex):sp
spine ex@(Vta fn arg _) =
let sp = spine fn
in (TypeArg arg, BecauseOf ex):sp
spine ex = [(ExprArg ex, BecauseOf ex)]

-- | Check the given 'Arg's against some 'Quantifiers', returning a set
-- of suspended 'App'/'Vta's and the result 'Type' of the expression.
checkArguments :: MonadInfer Typed m
=> [(Arg Desugared, SomeReason)]
-> Quantifiers Typed
-> m ( [ Expr Typed -> Expr Typed]
, Type Typed )
checkArguments ((ExprArg arg, _):as) (Quant tau dom cod inst_cont qs) =
case dom of
Anon dom -> do
x <- check arg dom

let cont ex = App ex x (annotation ex <> annotation x, cod)

(conts, result) <- checkArguments as qs
pure (inst_cont:cont:conts, result)
Invisible _ _ Req -> do
(_, t) <- infer arg
b <- freshTV
confesses (NotEqual tau (TyArr t b))
_ -> error "checkArguments ExprArg: impossible quantifier"

checkArguments ((TypeArg arg, reason):as) (Quant tau dom cod inst_cont qs) =
case dom of
Invisible v kind r | r /= Infer{} -> do
arg <- case kind of
Just k -> checkAgainstKind reason arg k
Nothing -> resolveKind reason arg

let ty = apply (Map.singleton v arg) cod
cont ex = ExprWrapper (TypeApp arg) ex (annotation ex <> annotation reason, ty)

(conts, result) <- checkArguments as qs
pure (inst_cont:cont:conts, result)
_ -> confesses (ArisingFrom (CanNotVta tau arg) reason)

checkArguments [] Quant{} = error "arity mismatch. impossible in checkArguments"

checkArguments _ (Result tau) = pure ([], tau)

-- | Perform the "quick look" pass, exposing as many quantifiers in the
-- function's type as there are 'Arg's given, returning a substitution
-- with impredicative instantiation and a linked data structure
-- representing the quantifiers.
quickLook :: MonadInfer Typed m
=> Type Typed
-> [(Arg Desugared, SomeReason, Maybe (Type Typed))]
Expand Down Expand Up @@ -172,22 +198,52 @@ quickLook t ((TypeArg tau, reason, _):args) = do
pure (sub, Quant t dom cod wrap qs)
quickLook tau [] = pure (mempty, Result tau)

-- | Return the impredicative instantiation from quickly looking at this
-- expression.
inferQL :: MonadInfer Typed m => (Arg Desugared, SomeReason) -> m (Maybe (Type Typed))
inferQL (arg, reason) = case arg of
ExprArg a -> inferQL_ex a
TypeArg tau -> pure <$> liftType reason tau

-- | Look at an expression quickly.
inferQL_ex :: MonadInfer Typed m => Expr Desugared -> m (Maybe (Type Typed))

#ifdef TRACE_TC
inferQL_ex ex | trace ("looking quickly at " ++ displayS (pretty ex)) False = undefined
#endif

inferQL_ex ex@(VarRef x _) = do
(_, _, (new, _)) <- third3A (discharge ex) =<< lookupTy' Strong x
(_, tau) <- censor (const mempty) $ instantiateTc (BecauseOf ex) new
if hasPoly tau
then pure (pure tau)
else pure Nothing

-- Look at expressions quickly. Here, we assume that traversing
-- expressions is "cheap", but solving constraints is "expensive". Since
-- the second pass will solve constraints anyway, we /don't need any
-- solving from the quick look pass/.
inferQL_ex ex@App{} = quiet $ do
t <- snd <$> inferApps ex Nothing
pure . snd <$> instantiateTc (BecauseOf ex) t
inferQL_ex ex@Vta{} = quiet $ do
t <- snd <$> inferApps ex Nothing
pure . snd <$> instantiateTc (BecauseOf ex) t
inferQL_ex ex@(BinOp l o r a) = quiet $ do
t <- snd <$> inferApps (App (App o l a) r a) Nothing
pure . snd <$> instantiateTc (BecauseOf ex) t

inferQL_ex (Literal l _) = pure (pure (litTy l))
inferQL_ex ex@(Ascription _ t _) = pure <$> liftType (BecauseOf ex) t
inferQL_ex _ = pure Nothing

-- | Is this type constructor invariant in its arguments?
invariant :: Map.Map (Var Typed) TySymInfo -> Var Typed -> Bool
invariant syms x =
x /= tyArrowName -- The function type is co/contravariant
&& x /= tyTupleName -- The tuple type is covariant
&& x `Map.notMember` syms -- Type families have weird variance that can't be solved quickly

data Arg p = ExprArg (Expr p) | TypeArg (Type p)
deriving instance (Show (Var p), Show (Ann p)) => Show (Arg p)

Expand Down Expand Up @@ -222,3 +278,6 @@ hasPoly :: Type Typed -> Bool
hasPoly = any isForall . universe where
isForall (TyPi Invisible{} _) = True
isForall _ = False

quiet :: MonadWriter w m => m a -> m a
quiet = censor (const mempty)
2 changes: 1 addition & 1 deletion tests/lua/default-method.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
do
local use = print
use(function(fw) return "tail" .. "()" end)
use(function(ge) return "tail" .. "()" end)
end
16 changes: 8 additions & 8 deletions tests/lua/monoid.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ do
local Nil = { __tag = "Nil" }
local tostring = tostring
local writeln = print
local function _dollardApplicativeako(tmp)
local function _dollardApplicativealc(tmp)
return {
pure = function(tmp0) return tmp.zero end,
["<*>"] = function(tmp0) return function(tmp1) return tmp["×"](tmp0)(tmp1) end end,
Expand All @@ -12,27 +12,27 @@ do
local function _colon_colon(x)
return function(y) return { { _1 = x, _2 = y }, __tag = "Cons" } end
end
local function _dollarshow(bam, x)
local function _dollarshow(bbv, x)
if x.__tag == "Nil" then return "Nil" end
local tmp = x[1]
return bam(tmp._1) .. " :: " .. _dollarshow(bam, tmp._2)
return bbv(tmp._1) .. " :: " .. _dollarshow(bbv, tmp._2)
end
local function _dollartraverse(brr, tmp, k, x)
local function _dollartraverse(bui, tmp, k, x)
if x.__tag == "Nil" then return tmp.pure(Nil) end
local tmp0 = x[1]
return tmp["<*>"](tmp["Applicative$kb"](_colon_colon)(k(tmp0._1)))(_dollartraverse(nil, tmp, k, tmp0._2))
end
local function _dollar_d7(bxc, x, ys)
local function _dollar_d7(cbj, x, ys)
if x.__tag == "Nil" then return ys end
local tmp = x[1]
return { { _2 = _dollar_d7(nil, tmp._2, ys), _1 = tmp._1 }, __tag = "Cons" }
end
local tmp = { _1 = 1, _2 = nil }
writeln(_dollarshow(function(x)
return tostring(x)
end, _dollartraverse(nil, _dollardApplicativeako({
zero = Nil,
["×"] = function(x) return function(ys) return _dollar_d7(nil, x, ys) end end
end, _dollartraverse(nil, _dollardApplicativealc({
["×"] = function(x) return function(ys) return _dollar_d7(nil, x, ys) end end,
zero = Nil
}), function(tmp0) return { { _1 = tmp0._1, _2 = Nil }, __tag = "Cons" } end, {
{
_1 = tmp,
Expand Down
12 changes: 6 additions & 6 deletions tests/lua/nested_match.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ do
if xs.__tag == "Nil" then return { { _1 = 1, _2 = Nil }, __tag = "Cons" } end
local tmp = xs[1]
if ys.__tag == "Nil" then return { { _1 = 2, _2 = Nil }, __tag = "Cons" } end
local tmp2 = ys[1]
local tmp3, tmp4 = tmp2._1, tmp2._2
local tmp0, tmp1 = tmp._1, tmp._2
if tmp0 ~= 0 then return { { _1 = f(tmp0)(tmp3), _2 = zip(f, tmp1, tmp4) }, __tag = "Cons" } end
if tmp3 == 0 then return { { _1 = 3, _2 = Nil }, __tag = "Cons" } end
return { { _1 = f(0)(tmp3), _2 = zip(f, tmp1, tmp4) }, __tag = "Cons" }
local tmp3, tmp4 = tmp._1, tmp._2
local tmp0 = ys[1]
local tmp1, tmp2 = tmp0._1, tmp0._2
if tmp3 ~= 0 then return { { _1 = f(tmp3)(tmp1), _2 = zip(f, tmp4, tmp2) }, __tag = "Cons" } end
if tmp1 == 0 then return { { _1 = 3, _2 = Nil }, __tag = "Cons" } end
return { { _1 = f(0)(tmp1), _2 = zip(f, tmp4, tmp2) }, __tag = "Cons" }
end
local function zip0(f) return function(xs) return function(ys) return zip(f, xs, ys) end end end
(nil)(zip0)
Expand Down
12 changes: 6 additions & 6 deletions tests/lua/stream-zip.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,27 @@ do
return Skip({ _2 = { _1 = sb, _2 = None }, _1 = tmp5[1] })
elseif tmp5.__tag == "Yield" then
local tmp6 = tmp5[1]
return Skip({ _1 = tmp6._2, _2 = { _1 = sb, _2 = Some(tmp6._1) } })
return Skip({ _2 = { _1 = sb, _2 = Some(tmp6._1) }, _1 = tmp6._2 })
elseif tmp5.__tag == "Done" then
return Done
end
else
local x0 = x[1]
local tmp5 = g(sb)
local x0 = x[1]
if tmp5.__tag == "Skip" then
return Skip({ _1 = sa, _2 = { _2 = Some(x0), _1 = tmp5[1] } })
return Skip({ _1 = sa, _2 = { _1 = tmp5[1], _2 = Some(x0) } })
elseif tmp5.__tag == "Yield" then
local tmp6 = tmp5[1]
return Yield({
_1 = { _1 = x0, _2 = tmp6._1 },
_2 = { _1 = sa, _2 = { _2 = None, _1 = tmp6._2 } }
_2 = { _1 = sa, _2 = { _2 = None, _1 = tmp6._2 } },
_1 = { _1 = x0, _2 = tmp6._1 }
})
elseif tmp5.__tag == "Done" then
return Done
end
end
end,
_2 = { _1 = start, _2 = { _1 = tmp2._2, _2 = None } }
_2 = { _1 = start, _2 = { _2 = None, _1 = tmp2._2 } }
})
end
end
Expand Down
6 changes: 3 additions & 3 deletions tests/lua/values_occ.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ do
local tmp = xs[1]
local go = tmp._1
return Stream({
_2 = tmp._2,
_1 = function(st)
local tmp0 = go(st)
local x = tmp0._1
return { _2 = tmp0._2, _1 = x * x }
end,
_2 = tmp._2
return { _1 = x * x, _2 = tmp0._2 }
end
})
end
print(to_string(sum_squares))
Expand Down
2 changes: 1 addition & 1 deletion tests/types/class/fundep01.out
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ gshow : Infer{'id : type}. ('id -> type) -> constraint
gshow : Spec{'f : 'id -> type}. gshow 'f => Infer{'id : type}. Spec{'x : 'id}. 'f 'x -> string
show : type -> constraint
show : Spec{'a : type}. show 'a => 'a -> string
genericShow : Infer{'a : type}. Infer{'f : 'ad -> type}. gshow 'f => generic 'a 'f => 'a -> string
genericShow : Infer{'a : type}. Infer{'rep : 'ad -> type}. gshow 'rep => generic 'a 'rep => 'a -> string
14 changes: 4 additions & 10 deletions tests/types/gadt/dict02-fail.out
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
dict02-fail.ml[17:17 ..17:22]: error
dict02-fail.ml[17:25 ..17:26]: error
17 │ let x = with_d (bar ()) () (fun x -> show x)
│ ^^^^^^
Couldn't match actual type dict (show int)
with the type expected by the context, dict ('c unit)
dict02-fail.ml[17:17 ..17:22]: error
17 │ let x = with_d (bar ()) () (fun x -> show x)
│ ^^^^^^
Couldn't match actual type dict (show int)
with the type expected by the context, dict ('c unit)
│ ^^
Couldn't match actual type unit
with the type expected by the context, int
2 changes: 1 addition & 1 deletion tests/types/gadt/fail_term.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fail_term.ml[16:3 ..16:30]: error
No instance for 'a -> 'b -> 'b ~ 'b -> 'y -> 'b arising in the expression
No instance for 'a -> 'b -> 'b ~ 'x -> 'y -> 'x arising in the expression
16 │ Lam (Lam (Var (There Here)))
│ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion tests/types/gadt/pass_term.out
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ term : type -> type -> type
Var : Spec{'ctx : type}. Spec{'ty : type}. elem 'ty 'ctx -> term 'ctx 'ty
Lam : Infer{'mu : type}. Spec{'ty : 'mu}. Spec{'a : type}. Spec{'b : type}. Spec{'ctx : type}. ('ty ~ 'a -> 'b) ⊃ term ('a * 'ctx) 'b -> term 'ctx 'ty
App : Infer{'ov : type}. Spec{'ty : 'ov}. Spec{'a : type}. Spec{'b : type}. Spec{'ctx : type}. ('ty ~ 'b) ⊃ (term 'ctx ('a -> 'b) * term 'ctx 'a) -> term 'ctx 'ty
const : Infer{'b : type}. Infer{'y : type}. Infer{'xs : type}. term 'xs ('b -> 'y -> 'b)
const : Infer{'x : type}. Infer{'y : type}. Infer{'xs : type}. term 'xs ('x -> 'y -> 'x)
2 changes: 1 addition & 1 deletion tests/types/unify.out
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ k1 : Infer{'a : type}. type -> 'a -> type
K1 : Spec{'c : type}. Spec{'a : 'a}. 'c -> k1 'c 'a
m1 : Infer{'a : type}. ('a -> type) -> 'a -> type
M1 : Spec{'f : 'a -> type}. Spec{'a : 'a}. 'f 'a -> m1 'f 'a
f : Infer{'a : type}. Infer{'b : 'a}. Infer{'g : 'a}. 'a -> m1 (sum (m1 (k1 'a)) 'g) 'b
f : Infer{'a : type}. Infer{'g : 'a}. Infer{'b : 'a}. 'a -> m1 (sum (m1 (k1 'a)) 'g) 'b

0 comments on commit 80461c9

Please sign in to comment.