Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for cleanup retains scheme #21350

Merged
merged 5 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,7 @@ object Trees {
override def isEmpty: Boolean = !hasType
override def toString: String =
s"TypeTree${if (hasType) s"[$typeOpt]" else ""}"
def isInferred = false
}

/** Tree that replaces a level 1 splices in pickled (level 0) quotes.
Expand All @@ -800,6 +801,7 @@ object Trees {
*/
class InferredTypeTree[+T <: Untyped](implicit @constructorOnly src: SourceFile) extends TypeTree[T]:
type ThisTree[+T <: Untyped] <: InferredTypeTree[T]
override def isInferred = true

/** ref.type */
case class SingletonTypeTree[+T <: Untyped] private[ast] (ref: Tree[T])(implicit @constructorOnly src: SourceFile)
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ extension (tp: Type)
def boxed(using Context): Type = tp.dealias match
case tp @ CapturingType(parent, refs) if !tp.isBoxed && !refs.isAlwaysEmpty =>
tp.annot match
case ann: CaptureAnnotation => AnnotatedType(parent, ann.boxedAnnot)
case ann: CaptureAnnotation =>
assert(!parent.derivesFrom(defn.Caps_CapSet))
AnnotatedType(parent, ann.boxedAnnot)
case ann => tp
case tp: RealTypeBounds =>
tp.derivedTypeBounds(tp.lo.boxed, tp.hi.boxed)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/cc/CapturingType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object CapturingType:
* boxing status is the same or if A is boxed.
*/
def apply(parent: Type, refs: CaptureSet, boxed: Boolean = false)(using Context): Type =
assert(!boxed || !parent.derivesFrom(defn.Caps_CapSet))
if refs.isAlwaysEmpty && !refs.keepAlways then parent
else parent match
case parent @ CapturingType(parent1, refs1) if boxed || !parent.isBoxed =>
Expand Down
54 changes: 39 additions & 15 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -388,23 +388,25 @@ class CheckCaptures extends Recheck, SymTransformer:
// should be included.
val included = cs.filter: c =>
c.stripReach match
case ref: TermRef =>
//if c.isReach then println(i"REACH $c in ${env.owner}")
//assert(!env.owner.isAnonymousFunction)
case ref: NamedType =>
val refSym = ref.symbol
val refOwner = refSym.owner
val isVisible = isVisibleFromEnv(refOwner)
if !isVisible && c.isReach && refSym.is(Param) && refOwner == env.owner then
if refSym.hasAnnotation(defn.UnboxAnnot) then
capt.println(i"exempt: $ref in $refOwner")
else
// Reach capabilities that go out of scope have to be approximated
// by their underlying capture set, which cannot be universal.
// Reach capabilities of @unboxed parameters are exempted.
val cs = CaptureSet.ofInfo(c)
cs.disallowRootCapability: () =>
report.error(em"Local reach capability $c leaks into capture scope of ${env.ownerString}", pos)
checkSubset(cs, env.captured, pos, provenance(env))
if !isVisible
&& (c.isReach || ref.isType)
&& refSym.is(Param)
&& refOwner == env.owner
then
if refSym.hasAnnotation(defn.UnboxAnnot) then
capt.println(i"exempt: $ref in $refOwner")
else
// Reach capabilities that go out of scope have to be approximated
// by their underlying capture set, which cannot be universal.
// Reach capabilities of @unboxed parameters are exempted.
val cs = CaptureSet.ofInfo(c)
cs.disallowRootCapability: () =>
report.error(em"Local reach capability $c leaks into capture scope of ${env.ownerString}", pos)
checkSubset(cs, env.captured, pos, provenance(env))
isVisible
case ref: ThisType => isVisibleFromEnv(ref.cls)
case _ => false
Expand Down Expand Up @@ -674,7 +676,29 @@ class CheckCaptures extends Recheck, SymTransformer:
i"Sealed type variable $pname", "be instantiated to",
i"This is often caused by a local capability$where\nleaking as part of its result.",
tree.srcPos)
handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
if meth == defn.Caps_containsImpl then checkContains(tree)
res
end recheckTypeApply

/** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked
* capability and assert that `{r} <:CS`.
*/
def checkContains(tree: TypeApply)(using Context): Unit =
tree.fun.knownType.widen match
case fntpe: PolyType =>
tree.args match
case csArg :: refArg :: Nil =>
val cs = csArg.knownType.captureSet
val ref = refArg.knownType
capt.println(i"check contains $cs , $ref")
ref match
case ref: CaptureRef if ref.isTracked =>
checkElem(ref, cs, tree.srcPos)
case _ =>
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
case _ =>
case _ =>

override def recheckBlock(tree: Block, pt: Type)(using Context): Type =
inNestedLevel(super.recheckBlock(tree, pt))
Expand Down
21 changes: 11 additions & 10 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
private def box(tp: Type)(using Context): Type =
def recur(tp: Type): Type = tp.dealiasKeepAnnotsAndOpaques match
case tp @ CapturingType(parent, refs) =>
if tp.isBoxed then tp else tp.boxed
if tp.isBoxed || parent.derivesFrom(defn.Caps_CapSet) then tp
else tp.boxed
case tp @ AnnotatedType(parent, ann) =>
if ann.symbol.isRetains
if ann.symbol.isRetains && !parent.derivesFrom(defn.Caps_CapSet)
then CapturingType(parent, ann.tree.toCaptureSet, boxed = true)
else tp.derivedAnnotatedType(box(parent), ann)
case tp1 @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp1) =>
Expand Down Expand Up @@ -329,10 +330,10 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
end transformExplicitType

/** Transform type of type tree, and remember the transformed type as the type the tree */
private def transformTT(tree: TypeTree, boxed: Boolean, exact: Boolean)(using Context): Unit =
private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit =
if !tree.hasRememberedType then
val transformed =
if tree.isInstanceOf[InferredTypeTree] && !exact
if tree.isInferred
then transformInferredType(tree.tpe)
else transformExplicitType(tree.tpe, tptToCheck = Some(tree))
tree.rememberType(if boxed then box(transformed) else transformed)
Expand Down Expand Up @@ -397,8 +398,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
&& !ccConfig.useSealed
&& !sym.hasAnnotation(defn.UncheckedCapturesAnnot),
// types of mutable variables are boxed in pre 3.3 code
exact = sym.allOverriddenSymbols.hasNext,
// types of symbols that override a parent don't get a capture set TODO drop
)
catch case ex: IllegalCaptureRef =>
capt.println(i"fail while transforming result type $tpt of $sym")
Expand Down Expand Up @@ -441,7 +440,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
traverse(fn)
if !defn.isTypeTestOrCast(fn.symbol) then
for case arg: TypeTree <- args do
transformTT(arg, boxed = true, exact = false) // type arguments in type applications are boxed
transformTT(arg, boxed = true) // type arguments in type applications are boxed

case tree: TypeDef if tree.symbol.isClass =>
val sym = tree.symbol
Expand All @@ -464,7 +463,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:

def postProcess(tree: Tree)(using Context): Unit = tree match
case tree: TypeTree =>
transformTT(tree, boxed = false, exact = false)
transformTT(tree, boxed = false)
case tree: ValOrDefDef =>
val sym = tree.symbol

Expand Down Expand Up @@ -605,8 +604,10 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
!refs.isEmpty
case tp: (TypeRef | AppliedType) =>
val sym = tp.typeSymbol
if sym.isClass then !sym.isPureClass
else instanceCanBeImpure(tp.superType)
if sym.isClass
then !sym.isPureClass
else !tp.derivesFrom(defn.Caps_CapSet) // CapSet arguments don't get other capture set variables added
&& instanceCanBeImpure(tp.superType)
case tp: (RefinedOrRecType | MatchType) =>
instanceCanBeImpure(tp.underlying)
case tp: AndType =>
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -993,15 +993,17 @@ class Definitions {
@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap")
@tu lazy val Caps_Capability: TypeSymbol = CapsModule.requiredType("Capability")
@tu lazy val Caps_CapSet = requiredClass("scala.caps.CapSet")
@tu lazy val Caps_CapSet: ClassSymbol = requiredClass("scala.caps.CapSet")
@tu lazy val Caps_reachCapability: TermSymbol = CapsModule.requiredMethod("reachCapability")
@tu lazy val Caps_capsOf: TermSymbol = CapsModule.requiredMethod("capsOf")
@tu lazy val Caps_Exists = requiredClass("scala.caps.Exists")
@tu lazy val Caps_Exists: ClassSymbol = requiredClass("scala.caps.Exists")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")
@tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox")
@tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability")
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")

@tu lazy val PureClass: Symbol = requiredClass("scala.Pure")

Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -806,10 +806,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
report.error(ex.toMessage, tree.srcPos.focus)
pickleErrorType()
case ex: AssertionError =>
println(i"error when pickling tree $tree")
println(i"error when pickling tree $tree of class ${tree.getClass}")
throw ex
case ex: MatchError =>
println(i"error when pickling tree $tree")
println(i"error when pickling tree $tree of class ${tree.getClass}")
throw ex
}
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
else keywordText("{{") ~ keywordText("/* inlined from ") ~ toText(call) ~ keywordText(" */") ~ bodyText ~ keywordText("}}")
case tpt: untpd.DerivedTypeTree =>
"<derived typetree watching " ~ tpt.watched.showSummary() ~ ">"
case TypeTree() =>
case tree: TypeTree =>
typeText(toText(tree.typeOpt))
~ Str("(inf)").provided(tree.isInstanceOf[InferredTypeTree] && printDebug)
~ Str("(inf)").provided(tree.isInferred && printDebug)
case SingletonTypeTree(ref) =>
toTextLocal(ref) ~ "." ~ keywordStr("type")
case RefinedTypeTree(tpt, refines) =>
Expand Down
39 changes: 19 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -303,20 +303,19 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
if !tree.symbol.is(Package) then tree
else errorTree(tree, em"${tree.symbol} cannot be used as a type")

// Cleans up retains annotations in inferred type trees. This is needed because
// during the typer, it is infeasible to correctly infer the capture sets in most
// cases, resulting ill-formed capture sets that could crash the pickler later on.
// See #20035.
private def cleanupRetainsAnnot(symbol: Symbol, tpt: Tree)(using Context): Tree =
/** Make result types of ValDefs and DefDefs that override some other definitions
* declared types rather than InferredTypes. This is necessary since we otherwise
* clean retains annotations from such types. But for an overriding symbol the
* retains annotations come from the explicitly declared parent types, so should
* be kept.
*/
private def makeOverrideTypeDeclared(symbol: Symbol, tpt: Tree)(using Context): Tree =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if I understand correctly, an overridden member:

  1. is typed normally like any definition during Typer,
  2. if its type tree is an InferredTypeTree, it get switched to a TypeTree here during PostTyper,
  3. gets its typed checked against the overloaded definition only later during RefChecks.

Seems to make sense ✅

tpt match
case tpt: InferredTypeTree
if !symbol.allOverriddenSymbols.hasNext =>
// if there are overridden symbols, the annotation comes from an explicit type of the overridden symbol
// and should be retained.
val tm = new CleanupRetains
val tpe1 = tm(tpt.tpe)
tpt.withType(tpe1)
case _ => tpt
if symbol.allOverriddenSymbols.hasNext =>
TypeTree(tpt.tpe, inferred = false).withSpan(tpt.span).withAttachmentsFrom(tpt)
case _ =>
tpt

override def transform(tree: Tree)(using Context): Tree =
try tree match {
Expand Down Expand Up @@ -432,7 +431,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
val tree1 = cpy.ValDef(tree)(tpt = cleanupRetainsAnnot(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
val tree1 = cpy.ValDef(tree)(tpt = makeOverrideTypeDeclared(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
checkStableSelection(tree.rhs)
processValOrDefDef(super.transform(tree1))
Expand All @@ -441,7 +440,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
annotateContextResults(tree)
val tree1 = cpy.DefDef(tree)(tpt = cleanupRetainsAnnot(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
val tree1 = cpy.DefDef(tree)(tpt = makeOverrideTypeDeclared(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
case tree: TypeDef =>
registerIfHasMacroAnnotations(tree)
Expand Down Expand Up @@ -524,12 +523,12 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
report.error(em"type ${alias.tpe} outside bounds $bounds", tree.srcPos)
super.transform(tree)
case tree: TypeTree =>
tree.withType(
tree.tpe match {
case AnnotatedType(tpe, annot) => AnnotatedType(tpe, transformAnnot(annot))
case tpe => tpe
}
)
val tpe = if tree.isInferred then CleanupRetains()(tree.tpe) else tree.tpe
tree.withType:
tpe match
case AnnotatedType(parent, annot) =>
AnnotatedType(parent, transformAnnot(annot)) // TODO: Also map annotations embedded in type?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also map annotations embedded in type?

Does that only transform outermost annotations? Shouldn't this be a TypeMap?

case _ => tpe
case Typed(Ident(nme.WILDCARD), _) =>
withMode(Mode.Pattern)(super.transform(tree))
// The added mode signals that bounds in a pattern need not
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ trait Applications extends Compatibility {
def makeVarArg(n: Int, elemFormal: Type): Unit = {
val args = typedArgBuf.takeRight(n).toList
typedArgBuf.dropRightInPlace(n)
val elemtpt = TypeTree(elemFormal)
val elemtpt = TypeTree(elemFormal, inferred = true)
typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt))
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ trait TypeAssigner {
else sym.info

private def toRepeated(tree: Tree, from: ClassSymbol)(using Context): Tree =
Typed(tree, TypeTree(tree.tpe.widen.translateToRepeated(from)))
Typed(tree, TypeTree(tree.tpe.widen.translateToRepeated(from), inferred = true))

def seqToRepeated(tree: Tree)(using Context): Tree = toRepeated(tree, defn.SeqClass)

Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
cpy.Block(block)(stats, expr1) withType expr1.tpe // no assignType here because avoid is redundant
case _ =>
val target = pt.simplified
val targetTpt = InferredTypeTree().withType(target)
val targetTpt = TypeTree(target, inferred = true)
if tree.tpe <:< target then Typed(tree, targetTpt)
else
// This case should not normally arise. It currently does arise in test cases
Expand Down Expand Up @@ -2092,7 +2092,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// TODO: move the check above to patternMatcher phase
val uncheckedTpe = AnnotatedType(sel.tpe.widen, Annotation(defn.UncheckedAnnot, tree.selector.span))
tpd.cpy.Match(result)(
selector = tpd.Typed(sel, new tpd.InferredTypeTree().withType(uncheckedTpe)),
selector = tpd.Typed(sel, tpd.TypeTree(uncheckedTpe, inferred = true)),
cases = result.cases
)
case _ =>
Expand Down
15 changes: 14 additions & 1 deletion library/src/scala/caps.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scala

import annotation.{experimental, compileTimeOnly}
import annotation.{experimental, compileTimeOnly, retainsCap}

@experimental object caps:

Expand All @@ -19,6 +19,19 @@ import annotation.{experimental, compileTimeOnly}
/** Carrier trait for capture set type parameters */
trait CapSet extends Any

/** A type constraint expressing that the capture set `C` needs to contain
* the capability `R`
*/
sealed trait Contains[C <: CapSet @retainsCap, R <: Singleton]

/** The only implementation of `Contains`. The constraint that `{R} <: C` is
* added separately by the capture checker.
*/
given containsImpl[C <: CapSet @retainsCap, R <: Singleton]: Contains[C, R]()

/** A wrapper indicating a type variable in a capture argument list of a
* @retains annotation. E.g. `^{x, Y^}` is represented as `@retains(x, capsOf[Y])`.
*/
@compileTimeOnly("Should be be used only internally by the Scala compiler")
def capsOf[CS]: Any = ???

Expand Down
11 changes: 11 additions & 0 deletions tests/neg-custom-args/captures/i21313.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Error: tests/neg-custom-args/captures/i21313.scala:6:27 -------------------------------------------------------------
6 |def foo(x: Async) = x.await(???) // error
| ^
| (x : Async) is not a tracked capability
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i21313.scala:15:12 ---------------------------------------
15 | ac1.await(src2) // error
| ^^^^
| Found: (src2 : Source[Int, caps.CapSet^{ac2}]^?)
| Required: Source[Int, caps.CapSet^{ac1}]^
|
| longer explanation available when compiling with `-explain`
15 changes: 15 additions & 0 deletions tests/neg-custom-args/captures/i21313.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import caps.CapSet

trait Async:
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T

def foo(x: Async) = x.await(???) // error

trait Source[+T, Cap^]:
final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.

def test(using ac1: Async^, ac2: Async^, x: String) =
val src1 = new Source[Int, CapSet^{ac1}] {}
ac1.await(src1) // ok
val src2 = new Source[Int, CapSet^{ac2}] {}
ac1.await(src2) // error
Loading
Loading