Skip to content

Fix #4440: Do not serialize the content of static objects #5775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,9 @@ class Definitions {
case List(pt) => (pt isRef StringClass)
case _ => false
}).symbol.asTerm

lazy val JavaSerializableClass: ClassSymbol = ctx.requiredClass("java.io.Serializable")

lazy val ComparableClass: ClassSymbol = ctx.requiredClass("java.lang.Comparable")

lazy val SystemClass: ClassSymbol = ctx.requiredClass("java.lang.System")
Expand Down Expand Up @@ -656,6 +658,11 @@ class Definitions {
lazy val Product_productPrefixR: TermRef = ProductClass.requiredMethodRef(nme.productPrefix)
def Product_productPrefix(implicit ctx: Context): Symbol = Product_productPrefixR.symbol

lazy val ModuleSerializationProxyType: TypeRef = ctx.requiredClassRef("scala.runtime.ModuleSerializationProxy")
def ModuleSerializationProxyClass(implicit ctx: Context): ClassSymbol = ModuleSerializationProxyType.symbol.asClass
lazy val ModuleSerializationProxyConstructor: TermSymbol =
ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(WildcardType)))

lazy val GenericType: TypeRef = ctx.requiredClassRef("scala.reflect.Generic")
def GenericClass(implicit ctx: Context): ClassSymbol = GenericType.symbol.asClass
lazy val ShapeType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape")
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,6 @@ object StdNames {
val productIterator: N = "productIterator"
val productPrefix: N = "productPrefix"
val raw_ : N = "raw"
val readResolve: N = "readResolve"
val reflect: N = "reflect"
val reflectiveSelectable: N = "reflectiveSelectable"
val reify : N = "reify"
Expand Down Expand Up @@ -558,6 +557,7 @@ object StdNames {
val withFilterIfRefutable: N = "withFilterIfRefutable$"
val WorksheetWrapper: N = "WorksheetWrapper"
val wrap: N = "wrap"
val writeReplace: N = "writeReplace"
val zero: N = "zero"
val zip: N = "zip"
val nothingRuntimeClass: N = "scala.runtime.Nothing$"
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,10 @@ object SymDenotations {
*/
def derivesFrom(base: Symbol)(implicit ctx: Context): Boolean = false

/** Is this symbol a class that extends `java.io.Serializable` ? */
def isSerializable(implicit ctx: Context): Boolean =
isClass && derivesFrom(defn.JavaSerializableClass)

/** Is this symbol a class that extends `AnyVal`? */
final def isValueClass(implicit ctx: Context): Boolean = {
val di = initial
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ class PlainPrinter(_ctx: Context) extends Printer {

def toText(const: Constant): Text = const.tag match {
case StringTag => stringText("\"" + escapedString(const.value.toString) + "\"")
case ClazzTag => "classOf[" ~ toText(const.typeValue.classSymbol) ~ "]"
case ClazzTag => "classOf[" ~ toText(const.typeValue) ~ "]"
case CharTag => literalText(s"'${escapedChar(const.charValue)}'")
case LongTag => literalText(const.longValue.toString + "L")
case EnumTag => literalText(const.symbolValue.name.toString)
Expand Down
47 changes: 38 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import ValueClasses.isDerivedValueClass
* def productArity: Int
* def productPrefix: String
*
* Special handling:
* protected def readResolve(): AnyRef
* Add to serializable static objects, unless an implementation
* already exists:
* private def writeReplace(): AnyRef
*
* Selectively added to value classes, unless a non-default
* implementation already exists:
Expand All @@ -50,8 +51,10 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols }
def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }

/** The synthetic methods of the case or value class `clazz`. */
def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
/** If this is a case or value class, return the appropriate additional methods,
* otherwise return nothing.
*/
def caseAndValueMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
val clazzType = clazz.appliedRef
lazy val accessors =
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
Expand Down Expand Up @@ -255,12 +258,38 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
*/
def canEqualBody(that: Tree): Tree = that.isInstance(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot)))

symbolsToSynthesize flatMap syntheticDefIfMissing
symbolsToSynthesize.flatMap(syntheticDefIfMissing)
}

def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template =
if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner))
cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass))
/** If this is a serializable static object `Foo`, add the method:
*
* private def writeReplace(): AnyRef =
* new scala.runtime.ModuleSerializationProxy(classOf[Foo.type])
*
* unless an implementation already exists, otherwise do nothing.
*/
def serializableObjectMethod(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
def hasWriteReplace: Boolean =
clazz.membersNamed(nme.writeReplace)
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
.exists
if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) {
val writeReplace = ctx.newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
List(
DefDef(writeReplace,
_ => New(defn.ModuleSerializationProxyType,
defn.ModuleSerializationProxyConstructor,
List(Literal(Constant(clazz.sourceModule.termRef)))))
.withSpan(ctx.owner.span.focus))
}
else
impl
Nil
}

def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template = {
val clazz = ctx.owner.asClass
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body)
}

}
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
if (typedArgs.length <= pt.paramInfos.length && !isNamed)
if (typedFn.symbol == defn.Predef_classOf && typedArgs.nonEmpty) {
val arg = typedArgs.head
checkClassType(arg.tpe, arg.sourcePos, traitReq = false, stablePrefixReq = false)
if (!arg.symbol.is(Module)) // Allow `classOf[Foo.type]` if `Foo` is an object
checkClassType(arg.tpe, arg.sourcePos, traitReq = false, stablePrefixReq = false)
}
case _ =>
}
Expand Down
54 changes: 54 additions & 0 deletions library/src/scala/runtime/ModuleSerializationProxy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copied from https://github.com/scala/scala/blob/2.13.x/src/library/scala/runtime/ModuleSerializationProxy.java
// TODO: Remove this file once we switch to the Scala 2.13 stdlib since it already contains it.

/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala.runtime;

import java.io.Serializable;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashSet;
import java.util.Set;

/** A serialization proxy for singleton objects */
public final class ModuleSerializationProxy implements Serializable {
private static final long serialVersionUID = 1L;
private final Class<?> moduleClass;
private static final ClassValue<Object> instances = new ClassValue<Object>() {
@Override
protected Object computeValue(Class<?> type) {
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<Object>) () -> type.getField("MODULE$").get(null));
} catch (PrivilegedActionException e) {
return rethrowRuntime(e.getCause());
}
}
};

private static Object rethrowRuntime(Throwable e) {
Throwable cause = e.getCause();
if (cause instanceof RuntimeException) throw (RuntimeException) cause;
else throw new RuntimeException(cause);
}

public ModuleSerializationProxy(Class<?> moduleClass) {
this.moduleClass = moduleClass;
}

@SuppressWarnings("unused")
private Object readResolve() {
return instances.get(moduleClass);
}
}
6 changes: 3 additions & 3 deletions library/src/scala/tasty/reflect/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -643,11 +643,11 @@ trait Printers

def keepDefinition(d: Definition): Boolean = {
val flags = d.symbol.flags
def isCaseClassUnOverridableMethod: Boolean = {
def isUndecompilableCaseClassMethod: Boolean = {
// Currently the compiler does not allow overriding some of the methods generated for case classes
d.symbol.flags.is(Flags.Synthetic) &&
(d match {
case DefDef("apply" | "unapply", _, _, _, _) if d.symbol.owner.flags.is(Flags.Object) => true
case DefDef("apply" | "unapply" | "writeReplace", _, _, _, _) if d.symbol.owner.flags.is(Flags.Object) => true
case DefDef(n, _, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
n == "copy" ||
n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method
Expand All @@ -657,7 +657,7 @@ trait Printers
})
}
def isInnerModuleObject = d.symbol.flags.is(Flags.Lazy) && d.symbol.flags.is(Flags.Object)
!flags.is(Flags.Param) && !flags.is(Flags.ParamAccessor) && !flags.is(Flags.FieldAccessor) && !isCaseClassUnOverridableMethod && !isInnerModuleObject
!flags.is(Flags.Param) && !flags.is(Flags.ParamAccessor) && !flags.is(Flags.FieldAccessor) && !isUndecompilableCaseClassMethod && !isInnerModuleObject
}
val stats1 = stats.collect {
case IsDefinition(stat) if keepDefinition(stat) => stat
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/classOf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ object Test {

def f1[T] = classOf[T] // error
def f2[T <: String] = classOf[T] // error
val x = classOf[Test.type] // error
val x = classOf[Test.type] // ok
val y = classOf[C { type I = String }] // error
val z = classOf[A] // ok
}
4 changes: 4 additions & 0 deletions tests/run/classof-object.decompiled
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/** Decompiled from out/runTestFromTasty/run/classof-object/Test.tasty */
object Test {
def main(args: scala.Array[scala.Predef.String]): scala.Unit = if (scala.Predef.classOf[Test.type].==(Test.getClass()).unary_!) dotty.DottyPredef.assertFail() else ()
}
5 changes: 5 additions & 0 deletions tests/run/classof-object.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
object Test {
def main(args: Array[String]): Unit = {
assert(classOf[Test.type] == Test.getClass)
}
}
4 changes: 2 additions & 2 deletions tests/run/literals.decompiled
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
object Test {
def αρετη: java.lang.String = "alpha rho epsilon tau eta"
case class GGG(i: scala.Int) {
def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i)
override def hashCode(): scala.Int = {
var acc: scala.Int = 767242539
acc = scala.runtime.Statics.mix(acc, GGG.this.i)
Expand All @@ -24,6 +23,7 @@ object Test {
case _ =>
throw new java.lang.IndexOutOfBoundsException(n.toString())
}
def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i)
}
object GGG extends scala.Function1[scala.Int, Test.GGG]
def check_success[a](name: scala.Predef.String, closure: => a, expected: a): scala.Unit = {
Expand Down Expand Up @@ -95,4 +95,4 @@ object Test {
val ggg: scala.Int = Test.GGG.apply(1).αα(Test.GGG.apply(2))
Test.check_success[scala.Int]("ggg == 3", ggg, 3)
}
}
}
27 changes: 27 additions & 0 deletions tests/run/module-serialization-proxy-class-unload.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import java.io.File

object Module {
val data = new Array[Byte](32 * 1024 * 1024)
}

object Test {
private val readResolve = classOf[scala.runtime.ModuleSerializationProxy].getDeclaredMethod("readResolve")
readResolve.setAccessible(true)

val testClassesDir = new File(Module.getClass.getClassLoader.getResource("Module.class").toURI).getParentFile
def main(args: Array[String]): Unit = {
for (i <- 1 to 256) {
// This would "java.lang.OutOfMemoryError: Java heap space" if ModuleSerializationProxy
// prevented class unloading.
deserializeDynamicLoadedClass()
}
}

def deserializeDynamicLoadedClass(): Unit = {
val loader = new java.net.URLClassLoader(Array(testClassesDir.toURI.toURL), ClassLoader.getSystemClassLoader)
val moduleClass = loader.loadClass("Module$")
assert(moduleClass ne Module.getClass)
val result = readResolve.invoke(new scala.runtime.ModuleSerializationProxy(moduleClass))
assert(result.getClass == moduleClass)
}
}
14 changes: 14 additions & 0 deletions tests/run/serialize.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,23 @@ object Test {
in.readObject.asInstanceOf[T]
}

object Foo extends Serializable {}

object Baz extends Serializable {
private def writeReplace(): AnyRef = {
this
}
}

def main(args: Array[String]): Unit = {
val x: PartialFunction[Int, Int] = { case x => x + 1 }
val adder = serializeDeserialize(x)
assert(adder(1) == 2)

val foo = serializeDeserialize(Foo)
assert(foo eq Foo)

val baz = serializeDeserialize(Baz)
assert(baz ne Baz)
}
}
Loading