Skip to content

Add dependent function types #3464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 27, 2017
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,15 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

override def Inlined(tree: Tree)(call: Tree, bindings: List[MemberDef], expansion: Tree)(implicit ctx: Context): Inlined = {
val tree1 = untpd.cpy.Inlined(tree)(call, bindings, expansion)
tree match {
case tree: Inlined if sameTypes(bindings, tree.bindings) && (expansion.tpe eq tree.expansion.tpe) =>
tree1.withTypeUnchecked(tree.tpe)
case _ => ta.assignType(tree1, bindings, expansion)
}
}

override def SeqLiteral(tree: Tree)(elems: List[Tree], elemtpt: Tree)(implicit ctx: Context): SeqLiteral = {
val tree1 = untpd.cpy.SeqLiteral(tree)(elems, elemtpt)
tree match {
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class GenAlias(pat: Tree, expr: Tree) extends Tree
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree]) extends TypTree
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree) extends DefTree
case class DependentTypeTree(tp: List[Symbol] => Type) extends Tree

@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY) with WithoutTypeOrPos[Untyped] {
override def isEmpty = true
Expand Down
13 changes: 10 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,8 @@ class Definitions {
val tsym = ft.typeSymbol
if (isFunctionClass(tsym)) {
val targs = ft.dealias.argInfos
Some(targs.init, targs.last, tsym.name.isImplicitFunction)
if (targs.isEmpty) None
else Some(targs.init, targs.last, tsym.name.isImplicitFunction)
}
else None
}
Expand Down Expand Up @@ -914,13 +915,19 @@ class Definitions {
def isProductSubType(tp: Type)(implicit ctx: Context) =
tp.derivesFrom(ProductType.symbol)

/** Is `tp` (an alias) of either a scala.FunctionN or a scala.ImplicitFunctionN? */
def isFunctionType(tp: Type)(implicit ctx: Context) = {
/** Is `tp` (an alias) of either a scala.FunctionN or a scala.ImplicitFunctionN
* instance?
*/
def isNonDepFunctionType(tp: Type)(implicit ctx: Context) = {
val arity = functionArity(tp)
val sym = tp.dealias.typeSymbol
arity >= 0 && isFunctionClass(sym) && tp.isRef(FunctionType(arity, sym.name.isImplicitFunction).typeSymbol)
}

/** Is `tp` a representation of a (possibly depenent) function type or an alias of such? */
def isFunctionType(tp: Type)(implicit ctx: Context) =
isNonDepFunctionType(tp.dropDependentRefinement)

// Specialized type parameters defined for scala.Function{0,1,2}.
private lazy val Function1SpecializedParams: collection.Set[Type] =
Set(IntType, LongType, FloatType, DoubleType)
Expand Down
6 changes: 0 additions & 6 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ object Mode {
/** Allow GADTFlexType labelled types to have their bounds adjusted */
val GADTflexible = newMode(8, "GADTflexible")

/** Allow dependent functions. This is currently necessary for unpickling, because
* some dependent functions are passed through from the front end(s?), even though they
* are technically speaking illegal.
*/
val AllowDependentFunctions = newMode(9, "AllowDependentFunctions")

/** We are currently printing something: avoid to produce more logs about
* the printing
*/
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ trait Symbols { this: Context =>
newClassSymbol(owner, name, flags, completer, privateWithin, coord, assocFile)
}

def newRefinedClassSymbol = newCompleteClassSymbol(
ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil)

/** Create a module symbol with associated module class
* from its non-info fields and a function producing the info
* of the module class (this info may be lazy).
Expand Down
45 changes: 36 additions & 9 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,12 @@ object Types {
case _ => this
}

/** Dealias, and if result is a dependent function type, drop the `apply` refinement. */
final def dropDependentRefinement(implicit ctx: Context): Type = dealias match {
case RefinedType(parent, nme.apply, _) => parent
case tp => tp
}

/** The type constructor of an applied type, otherwise the type itself */
final def typeConstructor(implicit ctx: Context): Type = this match {
case AppliedType(tycon, _) => tycon
Expand Down Expand Up @@ -1312,15 +1318,18 @@ object Types {
// ----- misc -----------------------------------------------------------

/** Turn type into a function type.
* @pre this is a non-dependent method type.
* @pre this is a method type without parameter dependencies.
* @param dropLast The number of trailing parameters that should be dropped
* when forming the function type.
*/
def toFunctionType(dropLast: Int = 0)(implicit ctx: Context): Type = this match {
case mt: MethodType if !mt.isDependent || ctx.mode.is(Mode.AllowDependentFunctions) =>
case mt: MethodType if !mt.isParamDependent =>
val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast
defn.FunctionOf(
formals1 mapConserve (_.underlyingIfRepeated(mt.isJavaMethod)), mt.resultType, mt.isImplicitMethod && !ctx.erasedTypes)
val funType = defn.FunctionOf(
formals1 mapConserve (_.underlyingIfRepeated(mt.isJavaMethod)),
mt.nonDependentResultApprox, mt.isImplicitMethod && !ctx.erasedTypes)
if (mt.isDependent) RefinedType(funType, nme.apply, mt)
else funType
}

/** The signature of this type. This is by default NotAMethod,
Expand Down Expand Up @@ -2581,7 +2590,7 @@ object Types {
def integrate(tparams: List[ParamInfo], tp: Type)(implicit ctx: Context): Type =
tparams match {
case LambdaParam(lam, _) :: _ => tp.subst(lam, this)
case tparams: List[Symbol @unchecked] => tp.subst(tparams, paramRefs)
case params: List[Symbol @unchecked] => tp.subst(params, paramRefs)
}

final def derivedLambdaType(paramNames: List[ThisName] = this.paramNames,
Expand Down Expand Up @@ -2688,7 +2697,7 @@ object Types {
* def f(x: C)(y: x.S) // dependencyStatus = TrueDeps
* def f(x: C)(y: x.T) // dependencyStatus = FalseDeps, i.e.
* // dependency can be eliminated by dealiasing.
*/
*/
private def dependencyStatus(implicit ctx: Context): DependencyStatus = {
if (myDependencyStatus != Unknown) myDependencyStatus
else {
Expand Down Expand Up @@ -2723,6 +2732,20 @@ object Types {
def isParamDependent(implicit ctx: Context): Boolean = paramDependencyStatus == TrueDeps

def newParamRef(n: Int) = new TermParamRef(this, n) {}

/** The least supertype of `resultType` that does not contain parameter dependencies */
def nonDependentResultApprox(implicit ctx: Context): Type =
if (isDependent) {
val dropDependencies = new ApproximatingTypeMap {
def apply(tp: Type) = tp match {
case tp @ TermParamRef(thisLambdaType, _) =>
range(tp.bottomType, atVariance(1)(apply(tp.underlying)))
case _ => mapOver(tp)
}
}
dropDependencies(resultType)
}
else resultType
}

abstract case class MethodType(paramNames: List[TermName])(
Expand Down Expand Up @@ -3197,8 +3220,10 @@ object Types {
case _ => false
}

protected def kindString: String

override def toString =
try s"ParamRef($paramName)"
try s"${kindString}ParamRef($paramName)"
catch {
case ex: IndexOutOfBoundsException => s"ParamRef(<bad index: $paramNum>)"
}
Expand All @@ -3207,8 +3232,9 @@ object Types {
/** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to
* refer to `TermParamRef(binder, paramNum)`.
*/
abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef {
abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef with SingletonType {
type BT = TermLambda
def kindString = "Term"
def copyBoundType(bt: BT) = bt.paramRefs(paramNum)
}

Expand All @@ -3217,6 +3243,7 @@ object Types {
*/
abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int) extends ParamRef {
type BT = TypeLambda
def kindString = "Type"
def copyBoundType(bt: BT) = bt.paramRefs(paramNum)

/** Looking only at the structure of `bound`, is one of the following true?
Expand Down Expand Up @@ -3731,7 +3758,7 @@ object Types {
// println(s"absMems: ${absMems map (_.show) mkString ", "}")
if (absMems.size == 1)
absMems.head.info match {
case mt: MethodType if !mt.isDependent => Some(absMems.head)
case mt: MethodType if !mt.isParamDependent => Some(absMems.head)
case _ => None
}
else if (tp isRef defn.PartialFunctionClass)
Expand Down
7 changes: 3 additions & 4 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class TreeUnpickler(reader: TastyReader, nameAtRef: NameRef => TermName, posUnpi
/** The unpickled trees */
def unpickle()(implicit ctx: Context): List[Tree] = {
assert(roots != null, "unpickle without previous enterTopLevel")
new TreeReader(reader).readTopLevel()(ctx.addMode(Mode.AllowDependentFunctions))
new TreeReader(reader).readTopLevel()
}

class Completer(owner: Symbol, reader: TastyReader) extends LazyType {
Expand Down Expand Up @@ -999,8 +999,7 @@ class TreeUnpickler(reader: TastyReader, nameAtRef: NameRef => TermName, posUnpi
val argPats = until(end)(readTerm())
UnApply(fn, implicitArgs, argPats, patType)
case REFINEDtpt =>
val refineCls = ctx.newCompleteClassSymbol(
ctx.owner, tpnme.REFINE_CLASS, NonMember, parents = Nil)
val refineCls = ctx.newRefinedClassSymbol
typeAtAddr(start) = refineCls.typeRef
val parent = readTpt()
val refinements = readStats(refineCls, end)(localContext(refineCls))
Expand Down Expand Up @@ -1096,7 +1095,7 @@ class TreeUnpickler(reader: TastyReader, nameAtRef: NameRef => TermName, posUnpi
class LazyReader[T <: AnyRef](reader: TreeReader, op: TreeReader => Context => T) extends Trees.Lazy[T] {
def complete(implicit ctx: Context): T = {
pickling.println(i"starting to read at ${reader.reader.currentAddr}")
op(reader)(ctx.addMode(Mode.AllowDependentFunctions).withPhaseNoLater(ctx.picklerPhase))
op(reader)(ctx.withPhaseNoLater(ctx.picklerPhase))
}
}

Expand Down
30 changes: 28 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@ object Parsers {
* | InfixType
* FunArgTypes ::= InfixType
* | `(' [ FunArgType {`,' FunArgType } ] `)'
* | '(' TypedFunParam {',' TypedFunParam } ')'
*/
def typ(): Tree = {
val start = in.offset
Expand All @@ -745,6 +746,16 @@ object Parsers {
val t = typ()
if (isImplicit) new ImplicitFunction(params, t) else Function(params, t)
}
def funArgTypesRest(first: Tree, following: () => Tree) = {
val buf = new ListBuffer[Tree] += first
while (in.token == COMMA) {
in.nextToken()
buf += following()
}
buf.toList
}
var isValParamList = false

val t =
if (in.token == LPAREN) {
in.nextToken()
Expand All @@ -754,10 +765,19 @@ object Parsers {
}
else {
openParens.change(LPAREN, 1)
val ts = commaSeparated(funArgType)
val paramStart = in.offset
val ts = funArgType() match {
case Ident(name) if name != tpnme.WILDCARD && in.token == COLON =>
isValParamList = true
funArgTypesRest(
typedFunParam(paramStart, name.toTermName),
() => typedFunParam(in.offset, ident()))
case t =>
funArgTypesRest(t, funArgType)
}
openParens.change(LPAREN, -1)
accept(RPAREN)
if (isImplicit || in.token == ARROW) functionRest(ts)
if (isImplicit || isValParamList || in.token == ARROW) functionRest(ts)
else {
for (t <- ts)
if (t.isInstanceOf[ByNameTypeTree])
Expand Down Expand Up @@ -790,6 +810,12 @@ object Parsers {
}
}

/** TypedFunParam ::= id ':' Type */
def typedFunParam(start: Offset, name: TermName): Tree = atPos(start) {
accept(COLON)
makeParameter(name, typ(), Modifiers(Param))
}

/** InfixType ::= RefinedType {id [nl] refinedType}
*/
def infixType(): Tree = infixTypeRest(refinedType())
Expand Down
26 changes: 14 additions & 12 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,14 @@ class PlainPrinter(_ctx: Context) extends Printer {
toTextRef(tp) ~ ".type"
case tp: TermRef if tp.denot.isOverloaded =>
"<overloaded " ~ toTextRef(tp) ~ ">"
case tp: SingletonType =>
toTextLocal(tp.underlying) ~ "(" ~ toTextRef(tp) ~ ")"
case tp: TypeRef =>
toTextPrefix(tp.prefix) ~ selectionString(tp)
case tp: TermParamRef =>
ParamRefNameString(tp) ~ ".type"
case tp: TypeParamRef =>
ParamRefNameString(tp) ~ lambdaHash(tp.binder)
case tp: SingletonType =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think something is wrong with the way dependent function types are printed:

scala> val depfun1: DF = (x: C) => x.m
val depfun1: (C => ){apply: (x: C): } = <function1>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed now.

toTextLocal(tp.underlying) ~ "(" ~ toTextRef(tp) ~ ")"
case AppliedType(tycon, args) =>
(toTextLocal(tycon) ~ "[" ~ Text(args map argText, ", ") ~ "]").close
case tp: RefinedType =>
Expand Down Expand Up @@ -180,26 +184,19 @@ class PlainPrinter(_ctx: Context) extends Printer {
case NoPrefix =>
"<noprefix>"
case tp: MethodType =>
def paramText(name: TermName, tp: Type) = toText(name) ~ ": " ~ toText(tp)
changePrec(GlobalPrec) {
(if (tp.isImplicitMethod) "(implicit " else "(") ~
Text((tp.paramNames, tp.paramInfos).zipped map paramText, ", ") ~
(if (tp.isImplicitMethod) "(implicit " else "(") ~ paramsText(tp) ~
(if (tp.resultType.isInstanceOf[MethodType]) ")" else "): ") ~
toText(tp.resultType)
}
case tp: ExprType =>
changePrec(GlobalPrec) { "=> " ~ toText(tp.resultType) }
case tp: TypeLambda =>
def paramText(name: Name, bounds: TypeBounds): Text = name.unexpandedName.toString ~ toText(bounds)
changePrec(GlobalPrec) {
"[" ~ Text((tp.paramNames, tp.paramInfos).zipped.map(paramText), ", ") ~
"]" ~ lambdaHash(tp) ~ (" => " provided !tp.resultType.isInstanceOf[MethodType]) ~
"[" ~ paramsText(tp) ~ "]" ~ lambdaHash(tp) ~
(" => " provided !tp.resultType.isInstanceOf[MethodType]) ~
toTextGlobal(tp.resultType)
}
case tp: TypeParamRef =>
ParamRefNameString(tp) ~ lambdaHash(tp.binder)
case tp: TermParamRef =>
ParamRefNameString(tp) ~ ".type"
case AnnotatedType(tpe, annot) =>
toTextLocal(tpe) ~ " " ~ toText(annot)
case tp: TypeVar =>
Expand All @@ -221,6 +218,11 @@ class PlainPrinter(_ctx: Context) extends Printer {
}
}.close

protected def paramsText(tp: LambdaType): Text = {
def paramText(name: Name, tp: Type) = toText(name) ~ toTextRHS(tp)
Text((tp.paramNames, tp.paramInfos).zipped.map(paramText), ", ")
}

protected def ParamRefNameString(name: Name): String = name.toString

protected def ParamRefNameString(param: ParamRef): String =
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
override def toText(tp: Type): Text = controlled {
def toTextTuple(args: List[Type]): Text =
"(" ~ Text(args.map(argText), ", ") ~ ")"

def toTextFunction(args: List[Type], isImplicit: Boolean): Text =
changePrec(GlobalPrec) {
val argStr: Text =
Expand All @@ -126,6 +127,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
("implicit " provided isImplicit) ~ argStr ~ " => " ~ argText(args.last)
}

def toTextDependentFunction(appType: MethodType): Text = {
("implicit " provided appType.isImplicitMethod) ~
"(" ~ paramsText(appType) ~ ") => " ~ toText(appType.resultType)
}

def isInfixType(tp: Type): Boolean = tp match {
case AppliedType(tycon, args) =>
args.length == 2 &&
Expand Down Expand Up @@ -158,6 +164,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
if (isInfixType(tp)) return toTextInfixType(tycon, args)
case EtaExpansion(tycon) =>
return toText(tycon)
case tp: RefinedType if defn.isFunctionType(tp) =>
return toTextDependentFunction(tp.refinedInfo.asInstanceOf[MethodType])
case tp: TypeRef =>
if (tp.symbol.isAnonymousClass && !ctx.settings.uniqid.value)
return toText(tp.info)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Dynamic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ trait Dynamic { self: Typer with Applications =>

tree.tpe.widen match {
case tpe: MethodType =>
if (tpe.isDependent)
fail(i"has a dependent method type")
if (tpe.isParamDependent)
fail(i"has a method type with inter-parameter dependencies")
else if (tpe.paramNames.length > Definitions.MaxStructuralMethodArity)
fail(i"""takes too many parameters.
|Structural types only support methods taking up to ${Definitions.MaxStructuralMethodArity} arguments""")
Expand Down
Loading