@@ -28,6 +28,16 @@ object SepChecker:
2828 else NeedsCheck
2929 end Captures
3030
31+ /** The kind of checked type, used for composing error messages */
32+ enum TypeKind :
33+ case Result (sym : Symbol , inferred : Boolean )
34+ case Argument
35+
36+ def dclSym = this match
37+ case Result (sym, _) => sym
38+ case _ => NoSymbol
39+ end TypeKind
40+
3141class SepChecker (checker : CheckCaptures .CheckerAPI ) extends tpd.TreeTraverser :
3242 import tpd .*
3343 import checker .*
@@ -204,7 +214,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
204214 for (arg, idx) <- indexedArgs do
205215 if arg.needsSepCheck then
206216 val ac = formalCaptures(arg)
207- checkType(arg.formalType, arg.srcPos, NoSymbol , " the argument's adapted type " )
217+ checkType(arg.formalType, arg.srcPos, TypeKind . Argument )
208218 val hiddenInArg = ac.hidden.footprint
209219 // println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
210220 val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
@@ -232,18 +242,29 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
232242 sepUseError(tree, usedFootprint, overlap)
233243
234244 def checkType (tpt : Tree , sym : Symbol )(using Context ): Unit =
235- checkType(tpt.nuType, tpt.srcPos, sym, " " )
236-
237- /** Check that all parts of type `tpe` are separated.
238- * @param tpe the type to check
239- * @param pos position for error reporting
240- * @param sym if `tpe` is the (result-) type of a val or def, the symbol of
241- * this definition, otherwise NoSymbol. If `sym` exists we
242- * deduct its associated direct and reach capabilities everywhere
243- * from the capture sets we check.
244- * @param what a string describing what kind of type it is
245- */
246- def checkType (tpe : Type , pos : SrcPos , sym : Symbol , what : String )(using Context ): Unit =
245+ checkType(tpt.nuType, tpt.srcPos,
246+ TypeKind .Result (sym, inferred = tpt.isInstanceOf [InferredTypeTree ]))
247+
248+ /** Check that all parts of type `tpe` are separated. */
249+ def checkType (tpe : Type , pos : SrcPos , kind : TypeKind )(using Context ): Unit =
250+
251+ def typeDescr = kind match
252+ case TypeKind .Result (sym, inferred) =>
253+ def inferredStr = if inferred then " inferred" else " "
254+ def resultStr = if sym.info.isInstanceOf [MethodicType ] then " result" else " "
255+ i " $sym's $inferredStr$resultStr"
256+ case TypeKind .Argument =>
257+ " the argument's adapted type"
258+
259+ def explicitRefs (tp : Type ): Refs = tp match
260+ case tp : (TermRef | ThisType ) => SimpleIdentitySet (tp)
261+ case AnnotatedType (parent, _) => explicitRefs(parent)
262+ case AndType (tp1, tp2) => explicitRefs(tp1) ++ explicitRefs(tp2)
263+ case OrType (tp1, tp2) => explicitRefs(tp1) ** explicitRefs(tp2)
264+ case _ => emptySet
265+
266+ def prune (refs : Refs ): Refs =
267+ refs.deductSym(kind.dclSym) -- explicitRefs(tpe)
247268
248269 def checkParts (parts : List [Type ]): Unit =
249270 var footprint : Refs = emptySet
@@ -265,21 +286,21 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
265286 if ! globalOverlap.isEmpty then
266287 val (prevStr, prevRefs, overlap) = parts.iterator.take(checked)
267288 .map: prev =>
268- val prevRefs = mapRefs(prev.deepCaptureSet.elems).footprint.deductSym(sym )
289+ val prevRefs = prune( mapRefs(prev.deepCaptureSet.elems).footprint)
269290 (i " , $prev , " , prevRefs, prevRefs.overlapWith(next))
270291 .dropWhile(_._3.isEmpty)
271292 .nextOption
272293 .getOrElse((" " , current, globalOverlap))
273294 report.error(
274- em """ Separation failure in $what type $tpe.
295+ em """ Separation failure in $typeDescr type $tpe.
275296 |One part, $part , $nextRel ${CaptureSet (next)}.
276297 |A previous part $prevStr $prevRel ${CaptureSet (prevRefs)}.
277298 |The two sets overlap at ${CaptureSet (overlap)}. """ ,
278299 pos)
279300
280301 val partRefs = part.deepCaptureSet.elems
281- val partFootprint = partRefs.footprint.deductSym(sym )
282- val partHidden = partRefs.hidden.footprint.deductSym(sym ) -- partFootprint
302+ val partFootprint = prune( partRefs.footprint)
303+ val partHidden = prune( partRefs.hidden.footprint) -- partFootprint
283304
284305 checkSep(footprint, partHidden, identity, " references" , " hides" )
285306 checkSep(hiddenSet, partHidden, _.hidden, " also hides" , " hides" )
@@ -325,9 +346,43 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
325346 case t =>
326347 foldOver(c, t)
327348
349+ def checkParameters () =
350+ val badParams = mutable.ListBuffer [Symbol ]()
351+ def currentOwner = kind.dclSym.orElse(ctx.owner)
352+ for hiddenRef <- prune(tpe.deepCaptureSet.elems.hidden.footprint) do
353+ val refSym = hiddenRef.termSymbol
354+ if refSym.is(TermParam )
355+ && ! refSym.hasAnnotation(defn.ConsumeAnnot )
356+ && ! refSym.info.derivesFrom(defn.Caps_SharedCapability )
357+ && currentOwner.isContainedIn(refSym.owner)
358+ then
359+ badParams += refSym
360+ if badParams.nonEmpty then
361+ def paramsStr (params : List [Symbol ]): String = (params : @ unchecked) match
362+ case p :: Nil => i " ${p.name}"
363+ case p :: p2 :: Nil => i " ${p.name} and ${p2.name}"
364+ case p :: ps => i " ${p.name}, ${paramsStr(ps)}"
365+ val (pluralS, singleS) = if badParams.tail.isEmpty then (" " , " s" ) else (" s" , " " )
366+ report.error(
367+ em """ Separation failure: $typeDescr type $tpe hides parameter $pluralS ${paramsStr(badParams.toList)}
368+ |The parameter $pluralS need $singleS to be annotated with @consume to allow this. """ ,
369+ pos)
370+
371+ def flagHiddenParams =
372+ kind match
373+ case TypeKind .Result (sym, _) =>
374+ ! sym.isAnonymousFunction // we don't check return types of anonymous functions
375+ && ! sym.is(Case ) // We don't check so far binders in patterns since they
376+ // have inferred universal types. TODO come back to this;
377+ // either infer more precise types for such binders or
378+ // "see through them" when we look at hidden sets.
379+ case TypeKind .Argument =>
380+ false
381+
328382 if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
329383 traverse(Captures .None , tpe)
330384 traverse.toCheck.foreach(checkParts)
385+ if flagHiddenParams then checkParameters()
331386 end checkType
332387
333388 private def collectMethodTypes (tp : Type ): List [TermLambda ] = tp match
0 commit comments