Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match capability class as CapturingType #16851

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ extension (tp: Type)
* The identity for all other types.
*/
def boxed(using Context): Type = tp.dealias match
case tp @ CapturingType(parent, refs) if !tp.isBoxed && !refs.isAlwaysEmpty =>
case tp @ CapturingType.Annotated(parent, refs) if !tp.isBoxed && !refs.isAlwaysEmpty =>
tp.annot match
case ann: CaptureAnnotation =>
ann.boxedType(tp)
Expand All @@ -79,6 +79,8 @@ extension (tp: Type)
case None => ann.tree.putAttachment(BoxedType, BoxedTypeCache())
case _ =>
ann.tree.attachment(BoxedType)(tp)
case CapturingType.Capability(parent, cs) =>
CapturingType(parent, cs).boxed
case tp: RealTypeBounds =>
tp.derivedTypeBounds(tp.lo.boxed, tp.hi.boxed)
case _ =>
Expand Down Expand Up @@ -190,6 +192,12 @@ extension (tp: Type)
case _ =>
false

def isCapabilityBase(using Context): Boolean = tp match
case tp: NamedType =>
val sym = tp.classSymbol
sym.exists && sym.asClass.isCapabilityBase
case _ => false

extension (cls: ClassSymbol)

def pureBaseClass(using Context): Option[Symbol] =
Expand All @@ -200,6 +208,9 @@ extension (cls: ClassSymbol)
selfType.exists && selfType.captureSet.isAlwaysEmpty
})

def isCapabilityBase(using Context): Boolean =
cls.is(Flags.CapabilityBase)

extension (sym: Symbol)

/** A class is pure if:
Expand Down Expand Up @@ -250,6 +261,11 @@ extension (tp: AnnotatedType)
case ann: CaptureAnnotation => ann.boxed
case _ => false

extension (tp: Type)
def isBoxed(using Context): Boolean = tp match
case tp: AnnotatedType => tp.isBoxed
case _ => false

extension (ts: List[Type])
/** Equivalent to ts.mapconserve(_.boxedUnlessFun(tycon)) but more efficient where
* it is the identity.
Expand Down
67 changes: 63 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CapturingType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package dotc
package cc

import core.*
import Decorators.*
import Types.*, Symbols.*, Contexts.*
import NameKinds.UniqueName
import util.SimpleIdentityMap

/** A (possibly boxed) capturing type. This is internally represented as an annotated type with a @retains
* or @retainsByName annotation, but the extractor will succeed only at phase CheckCaptures.
Expand Down Expand Up @@ -40,12 +43,14 @@ object CapturingType:
/** An extractor that succeeds only during CheckCapturingPhase. Boxing statis is
* returned separately by CaptureOps.isBoxed.
*/
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet)] =
def unapply(tp: Type)(using Context): Option[(Type, CaptureSet)] =
if ctx.phase == Phases.checkCapturesPhase
&& tp.annot.symbol == defn.RetainsAnnot
&& !ctx.mode.is(Mode.IgnoreCaptures)
then
EventuallyCapturingType.unapply(tp)
tp match
case Annotated(parent, cs) => Some(parent, cs)
case Capability(parent, cs) => Some(parent, cs)
case _ => None
else None

/** Check whether a type is uncachable when computing `baseType`.
Expand All @@ -58,15 +63,69 @@ object CapturingType:
ctx.phase == Phases.checkCapturesPhase &&
(Setup.isDuringSetup || ctx.mode.is(Mode.IgnoreCaptures) && tp.isEventuallyCapturingType)

object Annotated:
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet)] =
if ctx.phase == Phases.checkCapturesPhase
&& !ctx.mode.is(Mode.IgnoreCaptures)
&& tp.annot.symbol == defn.RetainsAnnot
then
EventuallyCapturingType.unapplyAnnot(tp)
else None

object Capability:
def unapply(tp: Type)(using Context): Option[(Type, CaptureSet)] =
if ctx.phase == Phases.checkCapturesPhase
&& !ctx.mode.is(Mode.IgnoreCaptures)
then
EventuallyCapturingType.unapplyCap(tp)
else None

end CapturingType


/** An extractor for types that will be capturing types at phase CheckCaptures. Also
* included are types that indicate captures on enclosing call-by-name parameters
* before phase ElimByName.
*/
object EventuallyCapturingType:

def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet)] =
object Annotated:
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet)] = unapplyAnnot(tp)

object Capability:
def unapply(tp: TypeRef)(using Context): Option[(Type, CaptureSet)] = unapplyCap(tp)

private var pureCapClassSymCache: SimpleIdentityMap[ClassSymbol, ClassSymbol] = SimpleIdentityMap.empty

private def createPureSymbolOf(csym: ClassSymbol)(using Context): ClassSymbol =
csym.copy(flags = csym.flags | Flags.CapabilityBase).asClass

private def pureSymbolOf(csym: ClassSymbol)(using Context): ClassSymbol =
pureCapClassSymCache(csym) match
case psym: ClassSymbol => psym
case null =>
val sym = createPureSymbolOf(csym)
pureCapClassSymCache = pureCapClassSymCache.updated(csym, sym)
sym

def unapply(tp: Type)(using Context): Option[(Type, CaptureSet)] =
tp match
case tp: AnnotatedType => unapplyAnnot(tp)
case _ => unapplyCap(tp)

def unapplyCap(tp: Type)(using Context): Option[(Type, CaptureSet)] =
if tp.classSymbol.hasAnnotation(defn.CapabilityAnnot) && !tp.classSymbol.is(Flags.CapabilityBase) then
val sym = tp.classSymbol
val psym = pureSymbolOf(sym.asClass)
tp match
case tp: TypeRef => Some((psym.typeRef, CaptureSet.universal))
case tp: AppliedType =>
Some((tp.derivedAppliedType(psym.typeRef, tp.args), CaptureSet.universal))
case _ => None
else None
// None

def unapplyAnnot(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet)] =
val sym = tp.annot.symbol
if sym == defn.RetainsAnnot || sym == defn.RetainsByNameAnnot then
tp.annot match
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ class CheckCaptures extends Recheck, SymTransformer:
expected

/** Adapt `actual` type to `expected` type by inserting boxing and unboxing conversions
*
*
* @param alwaysConst always make capture set variables constant after adaptation
*/
def adaptBoxed(actual: Type, expected: Type, pos: SrcPos, alwaysConst: Boolean = false)(using Context): Type =
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import transform.Recheck.*
import CaptureSet.IdentityCaptRefMap
import Synthetics.isExcluded
import util.Property
import reporting.trace

/** A tree traverser that prepares a compilation unit to be capture checked.
* It does the following:
Expand Down Expand Up @@ -408,8 +409,12 @@ extends tpd.TreeTraverser:
traverse(tree.rhs)
case tree @ TypeApply(fn, args) =>
traverse(fn)
val isErasedValue = fn match
case Ident(tp) =>
fn.symbol eq defn.Compiletime_erasedValue
case _ => false
for case arg: TypeTree <- args do
transformTT(arg, boxed = true, exact = false) // type arguments in type applications are boxed
transformTT(arg, boxed = !isErasedValue, exact = false) // type arguments in type applications are boxed
case _ =>
traverseChildren(tree)
tree match
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ class Definitions {
*/
object ByNameFunction:
def apply(tp: Type)(using Context): Type = tp match
case tp @ EventuallyCapturingType(tp1, refs) if tp.annot.symbol == RetainsByNameAnnot =>
case tp @ EventuallyCapturingType.Annotated(tp1, refs) if tp.annot.symbol == RetainsByNameAnnot =>
CapturingType(apply(tp1), refs)
case _ =>
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ object Flags {
/** Symbol cannot be found as a member during typer */
val (Invisible @ _, _, _) = newFlags(45, "<invisible>")

/** Symbol represents the pure base class of a capability class */
val (CapabilityBase @ _, _, _) = newFlags(46, "<capability-base>")

// ------------ Flags following this one are not pickled ----------------------------------

/** Symbol is not a member of its owner */
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2238,7 +2238,7 @@ object SymDenotations {
case tp: TypeParamRef => // uncachable, since baseType depends on context bounds
recur(TypeComparer.bounds(tp).hi)

case CapturingType(parent, refs) =>
case CapturingType.Annotated(parent, refs) =>
tp.derivedCapturingType(recur(parent), refs)

case tp: TypeProxy =>
Expand Down
32 changes: 28 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import typer.ProtoTypes.constrained
import typer.Applications.productSelectorTypes
import reporting.trace
import annotation.constructorOnly
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing, isBoxedCapturing, boxed, boxedUnlessFun, boxedIfTypeParam, isAlwaysPure}
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing, isBoxedCapturing, boxed, boxedUnlessFun, boxedIfTypeParam, isAlwaysPure, isCapabilityBase}

/** Provides methods to compare types.
*/
Expand Down Expand Up @@ -302,7 +302,14 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
// For convenience we want X$ <:< X.type
// This is safe because X$ self-type is X.type
sym1 = sym1.companionModule
if ((sym1 ne NoSymbol) && (sym1 eq sym2))

def isMatchingSymbols =
(sym1 eq sym2) || ctx.phase == Phases.checkCapturesPhase && {
sym1.name == sym2.name &&
sym1.is(Flags.CapabilityBase) || sym2.is(Flags.CapabilityBase)
}

if ((sym1 ne NoSymbol) && isMatchingSymbols)
ctx.erasedTypes ||
sym1.isStaticOwner ||
isSubPrefix(tp1.prefix, tp2.prefix) ||
Expand Down Expand Up @@ -1137,7 +1144,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
/** Subtype test for the hk application `tp2 = tycon2[args2]`.
*/
def compareAppliedType2(tp2: AppliedType, tycon2: Type, args2: List[Type]): Boolean = {
val tparams = tycon2.typeParams

def getTypeParams: List[TypeParamInfo] =
if ctx.phase != Phases.checkCapturesPhase then
tycon2.typeParams
else tycon2.match
case tycon2: TypeRef if tycon2.isCapabilityBase =>
tycon2.symbol.info.typeParams
case _ =>
tycon2.typeParams

val tparams = getTypeParams
if (tparams.isEmpty) return false // can happen for ill-typed programs, e.g. neg/tcpoly_overloaded.scala

/** True if `tp1` and `tp2` have compatible type constructors and their
Expand Down Expand Up @@ -1215,8 +1232,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
&& ctx.gadt.contains(tycon2sym)
&& ctx.gadt.isLess(tycon1sym, tycon2sym)

def isMatchingSymbols =
(tycon1sym == tycon2sym) || {
ctx.phase == Phases.checkCapturesPhase &&
tycon1sym.name == tycon2sym.name &&
tycon1sym.is(Flags.CapabilityBase) || tycon2sym.is(Flags.CapabilityBase)
}

val res = (
tycon1sym == tycon2sym && isSubPrefix(tycon1.prefix, tycon2.prefix)
isMatchingSymbols && isSubPrefix(tycon1.prefix, tycon2.prefix)
|| tycon1sym.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2))
|| tycon2sym.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo))
|| byGadtOrdering
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ object TypeOps:
tp1 match {
case tp1: RecType =>
return tp1.rebind(approximateOr(tp1.parent, tp2))
case CapturingType(parent1, refs1) =>
case CapturingType.Annotated(parent1, refs1) =>
return tp1.derivedCapturingType(approximateOr(parent1, tp2), refs1)
case err: ErrorType =>
return err
Expand All @@ -305,7 +305,7 @@ object TypeOps:
tp2 match {
case tp2: RecType =>
return tp2.rebind(approximateOr(tp1, tp2.parent))
case CapturingType(parent2, refs2) =>
case CapturingType.Annotated(parent2, refs2) =>
return tp2.derivedCapturingType(approximateOr(tp1, parent2), refs2)
case err: ErrorType =>
return err
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4974,7 +4974,7 @@ object Types {
else if (clsd.is(Module)) givenSelf
else if (ctx.erasedTypes) appliedRef
else givenSelf match
case givenSelf @ EventuallyCapturingType(tp, _) =>
case givenSelf @ EventuallyCapturingType.Annotated(tp, _) =>
givenSelf.derivedAnnotatedType(tp & appliedRef, givenSelf.annot)
case _ =>
AndType(givenSelf, appliedRef)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
~ (Str(": ") provided !tp.resultType.isInstanceOf[MethodOrPoly])
~ toText(tp.resultType)
}
case ExprType(ct @ EventuallyCapturingType(parent, refs))
case ExprType(ct @ EventuallyCapturingType.Annotated(parent, refs))
if ct.annot.symbol == defn.RetainsByNameAnnot =>
if refs.isUniversal then changePrec(GlobalPrec) { "=> " ~ toText(parent) }
else toText(CapturingType(ExprType(parent), refs))
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,14 @@ abstract class Recheck extends Phase, SymTransformer:
assert(false, i"unexpected type of ${tree.fun}: $funtpe")

def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type =
recheck(tree.fun).widen match
val tp = recheck(tree.fun)
tp.widen match
case fntpe: PolyType =>
assert(fntpe.paramInfos.hasSameLengthAs(tree.args))
val argTypes = tree.args.map(recheck(_))
constFold(tree, fntpe.instantiate(argTypes))
case otherTp =>
assert(false, i"unexpected function type when checking $tree, $tp ~~> $otherTp")

def recheckTyped(tree: Typed)(using Context): Type =
val tptType = recheck(tree.tpt)
Expand Down
4 changes: 2 additions & 2 deletions compiler/test/dotty/tools/dotc/Playground.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import dotty.tools.vulpix._
import org.junit.Test
import org.junit.Ignore

@Ignore class Playground:
class Playground:
import TestConfiguration._
import CompilationTests._

@Test def example: Unit =
implicit val testGroup: TestGroup = TestGroup("playground")
compileFile("tests/playground/example.scala", defaultOptions).checkCompile()
compileFile("/Users/linyxus/Workspace/dotty/issues/real-try.scala", defaultOptions).checkCompile()
28 changes: 20 additions & 8 deletions tests/neg-custom-args/captures/curried-simplified.check
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,40 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:11:39 ---------------------------
11 | def y3: Cap -> Protect[Int -> Int] = x3 // error
| ^^
| Found: ? (x$0: Cap) -> {x$0} Int -> Int
| Required: Cap -> Protect[Int -> Int]
| Found: ? (x$0: {*} Cap) -> {x$0} Int -> Int
| Required: Cap² -> Protect[Int -> Int]
|
| where: Cap is a class
| Cap² is a class
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:15:33 ---------------------------
15 | def y5: Cap -> {} Int -> Int = x5 // error
| ^^
| Found: ? Cap -> {x} Int -> Int
| Required: Cap -> {} Int -> Int
| Found: ? ({*} Cap) -> {x} Int -> Int
| Required: Cap² -> {} Int -> Int
|
| where: Cap is a class
| Cap² is a class
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:17:49 ---------------------------
17 | def y6: Cap -> {} Cap -> Protect[Int -> Int] = x6 // error
| ^^
| Found: ? (x$0: Cap) -> {x$0} (x$0: Cap) -> {x$0, x$0} Int -> Int
| Required: Cap -> {} Cap -> Protect[Int -> Int]
| Found: ? (x$0: {*} Cap) -> {x$0} (x$0: {*} Cap) -> {x$0, x$0} Int -> Int
| Required: Cap² -> {} Cap² -> Protect[Int -> Int]
|
| where: Cap is a class
| Cap² is a class
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:19:49 ---------------------------
19 | def y7: Cap -> Protect[Cap -> {} Int -> Int] = x7 // error
| ^^
| Found: ? (x$0: Cap) -> {x$0} (x: Cap) -> {x$0, x} Int -> Int
| Required: Cap -> Protect[Cap -> {} Int -> Int]
| Found: ? (x$0: {*} Cap) -> {x$0} (x: {*} Cap) -> {x$0, x} Int -> Int
| Required: Cap² -> Protect[Cap² -> {} Int -> Int]
|
| where: Cap is a class
| Cap² is a class
|
| longer explanation available when compiling with `-explain`
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ trait WrappedProperties extends PropertiesTrait {
object WrappedProperties {
object AccessControl extends WrappedProperties {
def wrap[T](body: => T): Option[T] =
try Some(body)
try
val result: T = body
Some(result)
catch {
// the actual exception we are concerned with is AccessControlException,
// but that's deprecated on JDK 17, so catching its superclass is a convenient
Expand Down