diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 67e1885b511f..dca2fbeb0dea 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1744,7 +1744,7 @@ object desugar { def adaptPatternArgs(elems: List[Tree], pt: Type)(using Context): List[Tree] = def reorderedNamedArgs(wildcardSpan: Span): List[untpd.Tree] = - var selNames = pt.namedTupleElementTypes.map(_(0)) + var selNames = pt.namedTupleElementTypes(false).map(_(0)) if selNames.isEmpty && pt.classSymbol.is(CaseClass) then selNames = pt.classSymbol.caseAccessors.map(_.name.asTermName) val nameToIdx = selNames.zipWithIndex.toMap diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 2890bdf306be..dd20c2db9192 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1337,10 +1337,25 @@ class Definitions { object NamedTuple: def apply(nmes: Type, vals: Type)(using Context): Type = AppliedType(NamedTupleTypeRef, nmes :: vals :: Nil) - def unapply(t: Type)(using Context): Option[(Type, Type)] = t match - case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol => - Some((nmes, vals)) - case _ => None + def unapply(t: Type)(using Context): Option[(Type, Type)] = + t match + case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol => + Some((nmes, vals)) + case tp: TypeProxy => + val t = unapply(tp.superType); t + case tp: OrType => + (unapply(tp.tp1), unapply(tp.tp2)) match + case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName => + Some(lhsName, lhsVal | rhsVal) + case _ => None + case tp: AndType => + (unapply(tp.tp1), unapply(tp.tp2)) match + case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName => + Some(lhsName, lhsVal & rhsVal) + case (lhs, None) => lhs + case (None, rhs) => rhs + case _ => None + case _ => None final def isCompiletime_S(sym: Symbol)(using Context): Boolean = sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass diff --git a/compiler/src/dotty/tools/dotc/core/TypeUtils.scala b/compiler/src/dotty/tools/dotc/core/TypeUtils.scala index 0a219fa6ddfd..14ccf32c7787 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeUtils.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeUtils.scala @@ -127,8 +127,17 @@ class TypeUtils: case Some(types) => TypeOps.nestedPairs(types) case None => throw new AssertionError("not a tuple") - def namedTupleElementTypesUpTo(bound: Int, normalize: Boolean = true)(using Context): List[(TermName, Type)] = + def namedTupleElementTypesUpTo(bound: Int, derived: Boolean, normalize: Boolean = true)(using Context): List[(TermName, Type)] = (if normalize then self.normalized else self).dealias match + // for desugaring and printer, ignore derived types to avoid infinite recursion in NamedTuple.unapply + case AppliedType(tycon, nmes :: vals :: Nil) if !derived && tycon.typeSymbol == defn.NamedTupleTypeRef.symbol => + val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map: + case ConstantType(Constant(str: String)) => str.toTermName + case t => throw TypeError(em"Malformed NamedTuple: names must be string types, but $t was found.") + val values = vals.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil) + names.zip(values) + case t if !derived => Nil + // default cause, used for post-typing case defn.NamedTuple(nmes, vals) => val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map: case ConstantType(Constant(str: String)) => str.toTermName @@ -138,22 +147,13 @@ class TypeUtils: case t => Nil - def namedTupleElementTypes(using Context): List[(TermName, Type)] = - namedTupleElementTypesUpTo(Int.MaxValue) + def namedTupleElementTypes(derived: Boolean)(using Context): List[(TermName, Type)] = + namedTupleElementTypesUpTo(Int.MaxValue, derived) def isNamedTupleType(using Context): Boolean = self match case defn.NamedTuple(_, _) => true case _ => false - def derivesFromNamedTuple(using Context): Boolean = self match - case defn.NamedTuple(_, _) => true - case tp: MatchType => - tp.bound.derivesFromNamedTuple || tp.reduced.derivesFromNamedTuple - case tp: TypeProxy => tp.superType.derivesFromNamedTuple - case tp: AndType => tp.tp1.derivesFromNamedTuple || tp.tp2.derivesFromNamedTuple - case tp: OrType => tp.tp1.derivesFromNamedTuple && tp.tp2.derivesFromNamedTuple - case _ => false - /** Drop all named elements in tuple type */ def stripNamedTuple(using Context): Type = self.normalized.dealias match case defn.NamedTuple(_, vals) => diff --git a/compiler/src/dotty/tools/dotc/interactive/Completion.scala b/compiler/src/dotty/tools/dotc/interactive/Completion.scala index ff5716b227ca..333af6a26b3b 100644 --- a/compiler/src/dotty/tools/dotc/interactive/Completion.scala +++ b/compiler/src/dotty/tools/dotc/interactive/Completion.scala @@ -532,7 +532,7 @@ object Completion: def namedTupleCompletionsFromType(tpe: Type): CompletionMap = val freshCtx = ctx.fresh.setExploreTyperState() inContext(freshCtx): - tpe.namedTupleElementTypes + tpe.namedTupleElementTypes(true) .map { (name, tpe) => val symbol = newSymbol(owner = NoSymbol, name, EmptyFlags, tpe) val denot = SymDenotation(symbol, NoSymbol, name, EmptyFlags, tpe) @@ -543,7 +543,7 @@ object Completion: .groupByName val qualTpe = qual.typeOpt - if qualTpe.derivesFromNamedTuple then + if qualTpe.isNamedTupleType then namedTupleCompletionsFromType(qualTpe) else if qualTpe.derivesFrom(defn.SelectableClass) then val pre = if !TypeOps.isLegalPrefix(qualTpe) then Types.SkolemType(qualTpe) else qualTpe diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index d460cec75115..27ab73f0fe4d 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -248,8 +248,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { def appliedText(tp: Type): Text = tp match case tp @ AppliedType(tycon, args) => val namedElems = - try tp.namedTupleElementTypesUpTo(200, normalize = false) - catch case ex: TypeError => Nil + try tp.namedTupleElementTypesUpTo(200, false, normalize = false) + catch + case ex: TypeError => Nil if namedElems.nonEmpty then toTextNamedTuple(namedElems) else tp.tupleElementTypesUpTo(200, normalize = false) match diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index c2864093ff70..164df6aae5b8 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -110,7 +110,7 @@ object Applications { } def namedTupleOrProductTypes(tp: Type)(using Context): List[Type] = - if tp.isNamedTupleType then tp.namedTupleElementTypes.map(_(1)) + if tp.isNamedTupleType then tp.namedTupleElementTypes(true).map(_(1)) else productSelectorTypes(tp, NoSourcePosition) def productSelectorTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = { diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 193cc443b4ae..9d273ebca866 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -876,7 +876,7 @@ trait Implicits: || inferView(dummyTreeOfType(from), to) (using ctx.fresh.addMode(Mode.ImplicitExploration).setExploreTyperState()).isSuccess // TODO: investigate why we can't TyperState#test here - || from.widen.derivesFromNamedTuple && to.derivesFrom(defn.TupleClass) + || from.widen.isNamedTupleType && to.derivesFrom(defn.TupleClass) && from.widen.stripNamedTuple <:< to ) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 76b853c4aabd..9b7e4fe36668 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -799,7 +799,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // Otherwise, try to expand a named tuple selection def tryNamedTupleSelection() = - val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes + val namedTupleElems = qual.tpe.widenDealias.namedTupleElementTypes(true) val nameIdx = namedTupleElems.indexWhere(_._1 == selName) if nameIdx >= 0 && Feature.enabled(Feature.namedTuples) then typed( @@ -875,7 +875,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer then val pre = if !TypeOps.isLegalPrefix(qual.tpe) then SkolemType(qual.tpe) else qual.tpe val fieldsType = pre.select(tpnme.Fields).widenDealias.simplified - val fields = fieldsType.namedTupleElementTypes + val fields = fieldsType.namedTupleElementTypes(true) typr.println(i"try dyn select $qual, $selName, $fields") fields.find(_._1 == selName) match case Some((_, fieldType)) => @@ -4663,7 +4663,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _: SelectionProto => tree // adaptations for selections are handled in typedSelect case _ if ctx.mode.is(Mode.ImplicitsEnabled) && tree.tpe.isValueType => - if tree.tpe.derivesFromNamedTuple && pt.derivesFrom(defn.TupleClass) then + if tree.tpe.isNamedTupleType && pt.derivesFrom(defn.TupleClass) then readapt(typed(untpd.Select(untpd.TypedSplice(tree), nme.toTuple))) else if pt.isRef(defn.AnyValClass, skipRefined = false) || pt.isRef(defn.ObjectClass, skipRefined = false) diff --git a/tests/run/i22150.check b/tests/run/i22150.check new file mode 100644 index 000000000000..4539bbf2d22d --- /dev/null +++ b/tests/run/i22150.check @@ -0,0 +1,3 @@ +0 +1 +2 diff --git a/tests/run/i22150.scala b/tests/run/i22150.scala new file mode 100644 index 000000000000..80c2222a98e7 --- /dev/null +++ b/tests/run/i22150.scala @@ -0,0 +1,26 @@ +//> using options -experimental -language:experimental.namedTuples +import language.experimental.namedTuples + +val directionsNT = IArray( + (dx = 0, dy = 1), // up + (dx = 1, dy = 0), // right + (dx = 0, dy = -1), // down + (dx = -1, dy = 0), // left +) +val IArray(UpNT @ _, _, _, _) = directionsNT + +object NT: + def foo[T <: (x: Int, y: String)](tup: T): Int = + tup.x + + def union[T](tup: (x: Int, y: String) | (x: Int, y: String)): Int = + tup.x + + def intersect[T](tup: (x: Int, y: String) & T): Int = + tup.x + + +@main def Test = + println(UpNT.dx) + println(NT.union((1, "a"))) + println(NT.intersect((2, "b")))