diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index a2d2d2cf358c..9b7d2b90ed1a 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -713,3 +713,21 @@ extension (self: Type) case _ => self +/** An extractor for a contains argument */ +object ContainsImpl: + def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] = + tree.fun.tpe.widen match + case fntpe: PolyType if tree.fun.symbol == defn.Caps_containsImpl => + tree.args match + case csArg :: refArg :: Nil => Some((csArg, refArg)) + case _ => None + case _ => None + +/** An extractor for a contains parameter */ +object ContainsParam: + def unapply(sym: Symbol)(using Context): Option[(TypeRef, CaptureRef)] = + sym.info.dealias match + case AppliedType(tycon, (cs: TypeRef) :: (ref: CaptureRef) :: Nil) + if tycon.typeSymbol == defn.Caps_ContainsTrait + && cs.typeSymbol.isAbstractOrParamType => Some((cs, ref)) + case _ => None diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala b/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala index 6578da89bbf8..f00c6869cd80 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala @@ -116,8 +116,12 @@ trait CaptureRef extends TypeProxy, ValueType: case x1: SingletonCaptureRef => x1.subsumes(y) case _ => false case x: TermParamRef => subsumesExistentially(x, y) + case x: TypeRef => assumedContainsOf(x).contains(y) case _ => false + def assumedContainsOf(x: TypeRef)(using Context): SimpleIdentitySet[CaptureRef] = + CaptureSet.assumedContains.getOrElse(x, SimpleIdentitySet.empty) + end CaptureRef trait SingletonCaptureRef extends SingletonType, CaptureRef diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 1d09b9dc5f20..25d8e0bc6506 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -16,7 +16,7 @@ import util.{SimpleIdentitySet, Property} import typer.ErrorReporting.Addenda import TypeComparer.subsumesExistentially import util.common.alwaysTrue -import scala.collection.mutable +import scala.collection.{mutable, immutable} import CCState.* /** A class for capture sets. Capture sets can be constants or variables. @@ -1125,6 +1125,12 @@ object CaptureSet: foldOver(cs, t) collect(CaptureSet.empty, tp) + type AssumedContains = immutable.Map[TypeRef, SimpleIdentitySet[CaptureRef]] + val AssumedContains: Property.Key[AssumedContains] = Property.Key() + + def assumedContains(using Context): AssumedContains = + ctx.property(AssumedContains).getOrElse(immutable.Map.empty) + private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key() /** Perform `op`. Under -Ycc-debug, collect and print info about all variables reachable diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index dbf01915122d..51cf362ca667 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -676,29 +676,24 @@ class CheckCaptures extends Recheck, SymTransformer: i"Sealed type variable $pname", "be instantiated to", i"This is often caused by a local capability$where\nleaking as part of its result.", tree.srcPos) - val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt))) - if meth == defn.Caps_containsImpl then checkContains(tree) - res + try handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt))) + finally checkContains(tree) end recheckTypeApply /** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked * capability and assert that `{r} <:CS`. */ - def checkContains(tree: TypeApply)(using Context): Unit = - tree.fun.knownType.widen match - case fntpe: PolyType => - tree.args match - case csArg :: refArg :: Nil => - val cs = csArg.knownType.captureSet - val ref = refArg.knownType - capt.println(i"check contains $cs , $ref") - ref match - case ref: CaptureRef if ref.isTracked => - checkElem(ref, cs, tree.srcPos) - case _ => - report.error(em"$refArg is not a tracked capability", refArg.srcPos) - case _ => - case _ => + def checkContains(tree: TypeApply)(using Context): Unit = tree match + case ContainsImpl(csArg, refArg) => + val cs = csArg.knownType.captureSet + val ref = refArg.knownType + capt.println(i"check contains $cs , $ref") + ref match + case ref: CaptureRef if ref.isTracked => + checkElem(ref, cs, tree.srcPos) + case _ => + report.error(em"$refArg is not a tracked capability", refArg.srcPos) + case _ => override def recheckBlock(tree: Block, pt: Type)(using Context): Type = inNestedLevel(super.recheckBlock(tree, pt)) @@ -814,15 +809,26 @@ class CheckCaptures extends Recheck, SymTransformer: val localSet = capturedVars(sym) if !localSet.isAlwaysEmpty then curEnv = Env(sym, EnvKind.Regular, localSet, curEnv) + + // ctx with AssumedContains entries for each Contains parameter + val bodyCtx = + var ac = CaptureSet.assumedContains + for paramSyms <- sym.paramSymss do + for case ContainsParam(cs, ref) <- paramSyms do + ac = ac.updated(cs, ac.getOrElse(cs, SimpleIdentitySet.empty) + ref) + if ac.isEmpty then ctx + else ctx.withProperty(CaptureSet.AssumedContains, Some(ac)) + inNestedLevel: // TODO: needed here? - try checkInferredResult(super.recheckDefDef(tree, sym), tree) + try checkInferredResult(super.recheckDefDef(tree, sym)(using bodyCtx), tree) finally if !sym.isAnonymousFunction then // Anonymous functions propagate their type to the enclosing environment // so it is not in general sound to interpolate their types. interpolateVarsIn(tree.tpt) curEnv = saved - + end recheckDefDef + /** If val or def definition with inferred (result) type is visible * in other compilation units, check that the actual inferred type * conforms to the expected type where all inferred capture sets are dropped. diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 1d2f2b05feb4..8981aa4aa6ac 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1002,7 +1002,7 @@ class Definitions { @tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox") @tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox") @tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg") - @tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability") + @tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains") @tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl") @tu lazy val PureClass: Symbol = requiredClass("scala.Pure") diff --git a/tests/pos-custom-args/captures/i21313.scala b/tests/pos-custom-args/captures/i21313.scala index 2fda6c0c0e45..b388b6487cb5 100644 --- a/tests/pos-custom-args/captures/i21313.scala +++ b/tests/pos-custom-args/captures/i21313.scala @@ -1,7 +1,16 @@ import caps.CapSet trait Async: - def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T + def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T = + val x: Async^{this} = ??? + val y: Async^{Cap^} = x + val ac: Async^ = ??? + def f(using caps.Contains[Cap, ac.type]) = + val x2: Async^{this} = ??? + val y2: Async^{Cap^} = x2 + val x3: Async^{ac} = ??? + val y3: Async^{Cap^} = x3 + ??? trait Source[+T, Cap^]: final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap. diff --git a/tests/run/Providers.check b/tests/run/Providers.check index 7b0a9a8b143e..a72c2c1e6fb7 100644 --- a/tests/run/Providers.check +++ b/tests/run/Providers.check @@ -18,3 +18,11 @@ Executing query: insert into subscribers(name, email) values Daniel daniel@Rockt You've just been subscribed to RockTheJVM. Welcome, Martin Acquired connection Executing query: insert into subscribers(name, email) values Martin odersky@gmail.com + +Injected2 +You've just been subscribed to RockTheJVM. Welcome, Daniel +Acquired connection +Executing query: insert into subscribers(name, email) values Daniel daniel@RocktheJVM.com +You've just been subscribed to RockTheJVM. Welcome, Martin +Acquired connection +Executing query: insert into subscribers(name, email) values Martin odersky@gmail.com diff --git a/tests/run/Providers.scala b/tests/run/Providers.scala index 3eb4b2df2207..8c5bf20bc02e 100644 --- a/tests/run/Providers.scala +++ b/tests/run/Providers.scala @@ -65,6 +65,8 @@ end Providers Explicit().test() println(s"\nInjected") Injected().test() + println(s"\nInjected2") + Injected2().test() /** Demonstrator for explicit dependency construction */ class Explicit: @@ -173,5 +175,55 @@ class Injected: end explicit end Injected +/** Injected with builders in companion objects */ +class Injected2: + import Providers.* + + case class User(name: String, email: String) + + class UserSubscription(emailService: EmailService, db: UserDatabase): + def subscribe(user: User) = + emailService.email(user) + db.insert(user) + object UserSubscription: + def apply()(using Provider[(EmailService, UserDatabase)]): UserSubscription = + new UserSubscription(provided[EmailService], provided[UserDatabase]) + + class EmailService: + def email(user: User) = + println(s"You've just been subscribed to RockTheJVM. Welcome, ${user.name}") + + class UserDatabase(pool: ConnectionPool): + def insert(user: User) = + pool.get().runQuery(s"insert into subscribers(name, email) values ${user.name} ${user.email}") + object UserDatabase: + def apply()(using Provider[(ConnectionPool)]): UserDatabase = + new UserDatabase(provided[ConnectionPool]) + + class ConnectionPool(n: Int): + def get(): Connection = + println(s"Acquired connection") + Connection() + + class Connection(): + def runQuery(query: String): Unit = + println(s"Executing query: $query") + + def test() = + given Provider[EmailService] = provide(EmailService()) + given Provider[ConnectionPool] = provide(ConnectionPool(10)) + given Provider[UserDatabase] = provide(UserDatabase()) + given Provider[UserSubscription] = provide(UserSubscription()) + + def subscribe(user: User)(using Provider[UserSubscription]) = + val sub = UserSubscription() + sub.subscribe(user) + + subscribe(User("Daniel", "daniel@RocktheJVM.com")) + subscribe(User("Martin", "odersky@gmail.com")) + end test +end Injected2 + +