Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature GADT support for class type parameters #11222

Merged
merged 2 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ object Inferencing {
* approx, see gadt-approximation-interaction.scala).
*/
def apply(tp: Type): Type = tp.dealias match {
case tp @ TypeRef(qual, nme) if (qual eq NoPrefix)
&& variance != 0
case tp @ TypeRef(qual, nme) if variance != 0
&& ctx.gadt.contains(tp.symbol)
=>
val sym = tp.symbol
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2030,6 +2030,15 @@ class Typer extends Namer
val tparamss = paramss1.collect {
case untpd.TypeDefs(tparams) => tparams
}

// Register GADT constraint for class type parameters from outer to inner class definition. (Useful when nested classes exist.) But do not cross a function definition.
if sym.flags.is(Method) then
rhsCtx.setFreshGADTBounds
ctx.outer.outersIterator.takeWhile(!_.owner.is(Method))
.filter(ctx => ctx.owner.isClass && ctx.owner.typeParams.nonEmpty)
.toList.reverse
.foreach(ctx => rhsCtx.gadt.addToConstraint(ctx.owner.typeParams))

if tparamss.nonEmpty then
rhsCtx.setFreshGADTBounds
val tparamSyms = tparamss.flatten.map(_.symbol)
Expand Down
14 changes: 14 additions & 0 deletions tests/pos/class-gadt/basic.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object basic {
enum Expr[A] {
case IntExpr(value: Int) extends Expr[Int]
case Other[T](value: T) extends Expr[T]
}

class C[A] {
def eval(e: Expr[A]): A =
e match {
case Expr.IntExpr(i) => i + 2
case Expr.Other(v) => v
}
}
}
15 changes: 15 additions & 0 deletions tests/pos/class-gadt/inheritance.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
object inheritance{
enum SUB[-A, +B]:
case Refl[S]() extends SUB[S, S]

class A[T](val v: T) {
val foo1: T = v
}

class C[T](val v1: T) extends A[T](v1) {
def eval1(t: T, e: SUB[T, Int]): Int =
e match {
case SUB.Refl() => foo1
}
}
}
11 changes: 11 additions & 0 deletions tests/pos/class-gadt/member.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
object member{
enum SUB[-A, +B]:
case Refl[S]() extends SUB[S, S]

class C[T] {
def eval1(t: T, e: SUB[T, Int]): Int =
e match {
case SUB.Refl() => t + 2
}
}
}
59 changes: 59 additions & 0 deletions tests/pos/class-gadt/nestedClass.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Nested class for GADT constraining
object nestedClass{
enum Expr[A] {
case IntExpr(value: Int) extends Expr[Int]
case Other[T](value: T) extends Expr[T]
}

class Outer1[C] {
class Inner1 {
def eval(e: Expr[C]): C =
e match {
case Expr.IntExpr(i) => i + 2
case Expr.Other(v) => v
}
}
}

def foo2[C](): Unit =
class Outer2 {
class Inner2 {
def eval(e: Expr[C]): C =
e match {
case Expr.IntExpr(i) => i + 2
case Expr.Other(v) => v
}
}
}

class Outer3[C] {
def foo3(): Unit =
class Inner3 {
def eval(e: Expr[C]): C =
e match {
case Expr.IntExpr(i) => i + 2
case Expr.Other(v) => v
}
}
}

trait Outer4[C] {
class Inner4 {
def eval(e: Expr[C]): C =
e match {
case Expr.IntExpr(i) => i + 2
case Expr.Other(v) => v
}
}
}

class Outer5[C] {
object Inner5 {
def eval(e: Expr[C]): C =
e match {
case Expr.IntExpr(i) => i + 2
case Expr.Other(v) => v
}
}
}
}
77 changes: 77 additions & 0 deletions tests/pos/class-gadt/variance.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
object variance {
enum SUB[-A, +B]:
case Refl[S]() extends SUB[S, S]
// Covariant
class C1[+T](v: T){
def foo(ev: T SUB Int): Int =
ev match {
case SUB.Refl() => v
}
}

// Contravariant
class B2[-T](v: T){}

class C2[-T](v: T){
def foo(ev: Int SUB T): B2[T] =
ev match {
case SUB.Refl() => new B2(v)
}
}

// Variance with inheritance

// superclass covariant and subclass covariant

class A3[+T](v: T) {
val value = v
}

class C3[+T](v: T) extends A3[T](v){
def foo(ev: T SUB Int): Int =
ev match {
case SUB.Refl() => value
}
}


// superclass covariant and subclass invariant
class A4[+T](v: T) {
val value = v
}

class C4[T](v: T) extends A4[T](v){
def foo(ev: T SUB Int): Int =
ev match {
case SUB.Refl() => value
}
}

// superclass contravariant and subclass contravariant
class B5[-T](v: T){}

class A5[-T](v: T) {
val value = new B5(v)
}

class C5[-T](v: T) extends A5[T](v){
def foo(ev: Int SUB T): B5[T] =
ev match {
case SUB.Refl() => value
}
}

// superclass contravariant and subclass invariant
class B6[-T](v: T){}

class A6[-T](v: T) {
val value = new B6(v)
}

class C6[-T](v: T) extends A6[T](v){
def foo(ev: Int SUB T): B6[T] =
ev match {
case SUB.Refl() => value
}
}
}
File renamed without changes.
File renamed without changes.