Skip to content

Commit

Permalink
more tests, refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Dec 23, 2023
1 parent 4fed9c5 commit 5ee9cfd
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 44 deletions.
95 changes: 56 additions & 39 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,16 @@ object Infer {
def subsCheckRho(t: Type, rho: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] =
(t, rho) match {
case (fa: Type.Quantified, rho) =>
// Rule SPEC
for {
(exSkols, faRho) <- instantiate(fa)
unskol = unskolemizeExists(exSkols)
coerce <- subsCheckRho2(faRho, rho, left, right)
} yield coerce.andThen(unskol)
subsInstantiate(fa, rho, left, right) match {
case Some(inf) => inf
case None =>
// Rule SPEC
for {
(exSkols, faRho) <- instantiate(fa)
unskol = unskolemizeExists(exSkols)
coerce <- subsCheckRho2(faRho, rho, left, right)
} yield coerce.andThen(unskol)
}
// for existential lower bounds, we skolemize the existentials
// then verify they don't escape after inference and unskolemize
// them (if they are free in the resulting type)
Expand All @@ -616,7 +620,7 @@ object Infer {
(a1, r1) <- unifyFnRho(a2.length, rho1, left, right)
// since rho is in weak prenex form, and Fun is covariant on r2, we know
// r2 is in weak-prenex form and a rho type
rhor2 <- assertRho(r2, s"subsCheckRho2($t, $rho, $left, $right), line 521", right)
rhor2 <- assertRho(r2, s"subsCheckRho2($t, $rho, $left, $right), line 619", right)
coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right)
} yield coerce
case (Type.Fun(a1, r1), rho2) =>
Expand All @@ -625,7 +629,7 @@ object Infer {
(a2, r2) <- unifyFnRho(a1.length, rho2, right, left)
// since rho is in weak prenex form, and Fun is covariant on r2, we know
// r2 is in weak-prenex form
rhor2 <- assertRho(r2, s"subsCheckRho($t, $rho, $left, $right), line 471", right)
rhor2 <- assertRho(r2, s"subsCheckRho2($t, $rho, $left, $right), line 628", right)
coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right)
} yield coerce
case (rho1, ta@Type.TyApply(l2, r2)) =>
Expand Down Expand Up @@ -993,16 +997,38 @@ object Infer {
}
} yield res

// DEEP-SKOL rule
def subsInstantiate(inferred: Type, declared: Type, left: Region, right: Region): Option[Infer[TypedExpr.Coerce]] =
inferred match {
case Type.ForAll(vars, inT) =>
Type.instantiate(vars.iterator.toMap, inT, declared).map { case (_, subs) =>
validateSubs(subs.toList, left, right)
.as {
new FunctionK[TypedExpr, TypedExpr] {
def apply[A](te: TypedExpr[A]): TypedExpr[A] =
// we apply the annotation here and let Normalization
// instantiate. We could explicitly have
// instantiation TypedExpr where you pass the variables to set
TypedExpr.Annotation(te, declared)
}
}
}
case _ => None
}
// note, this is identical to subsCheckRho when declared is a Rho type
def subsCheck(inferred: Type, declared: Type, left: Region, right: Region): Infer[TypedExpr.Coerce] =
subsUpper[TypedExpr, cats.Id](declared, right, pure(inferred :: Nil)) { (_, rho) =>
// TODO: we are ignoring the metas, but we can't easily write them
// with the current design since Coerce can't do any Meta writing
subsCheckRho(inferred, rho, left, right)
} {
Error.SubsumptionCheckFailure(inferred, declared, left, right, _)
def subsCheck(inferred: Type, declared: Type, left: Region, right: Region): Infer[TypedExpr.Coerce] = {
subsInstantiate(inferred, declared, left, right) match {
case Some(inf) => inf
case None =>
// DEEP-SKOL rule
subsUpper[TypedExpr, cats.Id](declared, right, pure(inferred :: Nil)) { (_, rho) =>
// TODO: we are ignoring the metas, but we can't easily write them
// with the current design since Coerce can't do any Meta writing
subsCheckRho(inferred, rho, left, right)
} {
Error.SubsumptionCheckFailure(inferred, declared, left, right, _)
}
}
}

def inferForAll[A: HasRegion](tpes: NonEmptyList[(Type.Var.Bound, Kind)], expr: Expr[A]): Infer[TypedExpr[A]] =
for {
Expand Down Expand Up @@ -1104,6 +1130,18 @@ object Infer {
}
}

def validateSubs(list: List[(Type.Var.Bound, (Kind, Type))], left: Region, right: Region): Infer[Unit] =
list.parTraverse_ { case (boundVar, (kind, tpe)) =>
kindOf(tpe, right).flatMap { k =>
if (Kind.leftSubsumesRight(kind, k)) {
unit
}
else {
fail(Error.KindMismatch(Type.TyVar(boundVar), kind, tpe, k, left, right))
}
}
}

def checkApply[A: HasRegion](fn: Expr[A], args: NonEmptyList[Expr[A]], tag: A, tpe: Type, tpeRegion: Region): Infer[TypedExpr[A]] = {
val infOpt = maybeSimple(fn).flatTraverse { inferFnExpr =>
inferFnExpr.map { fnTe =>
Expand Down Expand Up @@ -1133,17 +1171,7 @@ object Infer {
infOpt.flatMap {
case Some((fnTe, inT, frees, inst)) =>
val regTe = region(tag)
val validKinds: Infer[Unit] =
inst.toList.parTraverse_ { case (boundVar, (kind, tpe)) =>
kindOf(tpe, regTe).flatMap { k =>
if (Kind.leftSubsumesRight(kind, k)) {
unit
}
else {
fail(Error.KindMismatch(Type.TyVar(boundVar), kind, tpe, k, regTe, tpeRegion))
}
}
}
val validKinds: Infer[Unit] = validateSubs(inst.toList, region(fn), regTe)
val instNoKind = inst.iterator
.map { case (k, (_, t)) => (k, t) }
.toMap[Type.Var, Type]
Expand Down Expand Up @@ -1399,19 +1427,8 @@ object Infer {
// substitute
// see if substitute rho with subs <:< expected
// else set inferred value
val regTe = region(te)
val regTag = region(tag)
val validKinds: Infer[Unit] =
subs.toList.parTraverse_ { case (boundVar, (kind, tpe)) =>
kindOf(tpe, regTe).flatMap { k =>
if (Kind.leftSubsumesRight(kind, k)) {
unit
}
else {
fail(Error.KindMismatch(Type.TyVar(boundVar), kind, tpe, k, regTe, regTag))
}
}
}
validateSubs(subs.toList, region(term), region(tag))

validKinds.parProductR(expect match {
case Expected.Check((r1, reg1)) =>
Expand Down
6 changes: 2 additions & 4 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,8 @@ object Type {
case TyVar(toB: Var.Bound) =>
state.rightFrees.get(toB) match {
case Some(toBKind) =>
// TODO we could substitute to a compatible kind if
// we track it and return it correctly
if (kind === toBKind) {
Some(state.updated(b, (kind, Free(toB))))
if (Kind.leftSubsumesRight(kind, toBKind)) {
Some(state.updated(b, (toBKind, Free(toB))))
}
else None
case None => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,11 @@ class RankNInferTest extends AnyFunSuite {
assert_:<:("List[forall a. a -> Int]", "List[(forall a. a) -> Int]")

assertTypesUnify("forall f: +* -> *. f[forall a. a]", "forall a. forall f: +* -> *. f[a]")
assertTypesDisjoint("forall f: * -> *. f[forall a. a]", "forall a. forall f: * -> *. f[a]")
assert_:<:("forall f: * -> *. f[Int]", "forall f: +* -> *. f[Int]")
assert_:<:("forall f: * -> *. f[Int]", "forall f: -* -> *. f[Int]")
assert_:<:("forall f: +* -> *. f[Int]", "forall f: 👻* -> *. f[Int]")
assert_:<:("forall f: -* -> *. f[Int]", "forall f: 👻* -> *. f[Int]")
assert_:<:("forall a. forall f: * -> *. f[a]", "forall f: * -> *. f[forall a. a]")
assert_:<:("forall a. forall f: -* -> *. f[a]", "forall f: -* -> *. f[forall a. a]")

assertTypesUnify("(forall a. a) -> Int", "(forall a. a) -> Int")
Expand Down

0 comments on commit 5ee9cfd

Please sign in to comment.