@@ -346,17 +346,19 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
346346 case t =>
347347 foldOver(c, t)
348348
349- def checkParams (refsToCheck : Refs , descr : => String ) =
349+ def checkRefs (refsToCheck : Refs , descr : => String ) =
350350 val badParams = mutable.ListBuffer [Symbol ]()
351351 def currentOwner = kind.dclSym.orElse(ctx.owner)
352- for hiddenRef <- prune(refsToCheck.footprint) do
353- val refSym = hiddenRef.termSymbol
354- if refSym.is(TermParam )
355- && ! refSym.hasAnnotation(defn.ConsumeAnnot )
356- && ! refSym.info.derivesFrom(defn.Caps_SharedCapability )
357- && currentOwner.isContainedIn(refSym.owner)
358- then
359- badParams += refSym
352+ for hiddenRef <- prune(refsToCheck) do
353+ val refSym = hiddenRef.pathRoot.termSymbol // TODO also hangle ThisTypes as pathRoots
354+ if refSym.exists && ! refSym.info.derivesFrom(defn.Caps_SharedCapability ) then
355+ if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.owner.enclosingMethodOrClass) then
356+ report.error(em """ Separation failure: $descr non-local $refSym""" , pos)
357+ else if refSym.is(TermParam )
358+ && ! refSym.hasAnnotation(defn.ConsumeAnnot )
359+ && currentOwner.isContainedIn(refSym.owner)
360+ then
361+ badParams += refSym
360362 if badParams.nonEmpty then
361363 def paramsStr (params : List [Symbol ]): String = (params : @ unchecked) match
362364 case p :: Nil => i " ${p.name}"
@@ -368,25 +370,28 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
368370 |The parameter $pluralS need $singleS to be annotated with @consume to allow this. """ ,
369371 pos)
370372
371- def checkParameters () = kind match
373+ def checkLegalRefs () = kind match
372374 case TypeKind .Result (sym, _) =>
373375 if ! sym.isAnonymousFunction // we don't check return types of anonymous functions
374376 && ! sym.is(Case ) // We don't check so far binders in patterns since they
375377 // have inferred universal types. TODO come back to this;
376378 // either infer more precise types for such binders or
377379 // "see through them" when we look at hidden sets.
378- then checkParams(tpe.deepCaptureSet.elems.hidden, i " $typeDescr type $tpe hides " )
380+ then
381+ val refs = tpe.deepCaptureSet.elems
382+ val toCheck = refs.hidden.footprint -- refs.footprint
383+ checkRefs(toCheck, i " $typeDescr type $tpe hides " )
379384 case TypeKind .Argument (arg) =>
380385 if tpe.hasAnnotation(defn.ConsumeAnnot ) then
381386 val capts = captures(arg)
382387 def descr (verb : String ) = i " argument to @consume parameter with type ${arg.nuType} $verb"
383- checkParams (capts, descr(" refers to" ))
384- checkParams (capts.hidden, descr(" hides" ))
388+ checkRefs (capts.footprint , descr(" refers to" ))
389+ checkRefs (capts.hidden.footprint , descr(" hides" ))
385390
386391 if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
387392 traverse(Captures .None , tpe)
388393 traverse.toCheck.foreach(checkParts)
389- checkParameters ()
394+ checkLegalRefs ()
390395 end checkType
391396
392397 private def collectMethodTypes (tp : Type ): List [TermLambda ] = tp match
@@ -426,10 +431,12 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
426431 if argss.nestedExists(_.needsSepCheck) then
427432 checkApply(tree, argss.flatten, dependencies(tree, argss))
428433
434+ def isUnsafeAssumeSeparate (tree : Tree )(using Context ): Boolean = tree match
435+ case tree : Apply => tree.symbol == defn.Caps_unsafeAssumeSeparate
436+ case _ => false
437+
429438 def traverse (tree : Tree )(using Context ): Unit =
430- tree match
431- case tree : Apply if tree.symbol == defn.Caps_unsafeAssumeSeparate => return
432- case _ =>
439+ if isUnsafeAssumeSeparate(tree) then return
433440 checkUse(tree)
434441 tree match
435442 case tree : GenericApply =>
@@ -446,7 +453,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
446453 defsShadow = saved
447454 case tree : ValOrDefDef =>
448455 traverseChildren(tree)
449- if ! tree.symbol.isOneOf(TermParamOrAccessor ) then
456+ if ! tree.symbol.isOneOf(TermParamOrAccessor ) && ! isUnsafeAssumeSeparate(tree.rhs) then
450457 checkType(tree.tpt, tree.symbol)
451458 if previousDefs.nonEmpty then
452459 capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
@@ -460,5 +467,3 @@ end SepChecker
460467
461468
462469
463-
464-
0 commit comments