From d2fc35861f0d6bac8d83a47ec3f141c28410c5d3 Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Mon, 24 Mar 2025 11:15:00 +0100 Subject: [PATCH] Mix in the `productPrefix` hash statically in case class `hashCode` Since 2.13, case class `hashCode` mixes in the hash code of the `productPrefix` string. This is inconsistent with the `equals` method, subclasses of case classes that override `productPrefix` compare equal but have a different `hashCode`. This commit changes `hashCode` to mix in the `productPrefix.hashCode` statically instead of invoking `productPrefix` at runtime. For case classes without primitive fields, the synthetic `hashCode` invokes `ScalaRunTime._hashCode`, which mixes in the result of `productPrefix`. To fix that, the synthetic hashCode is changed to invoke `MurmurHash3.productHash` directly and mix in the name to the seed statically. --- .../dotty/tools/dotc/core/Definitions.scala | 3 + .../dotc/transform/SyntheticMembers.scala | 66 +++++++++--------- tests/run-macros/tasty-extractors-2.check | 2 +- tests/run/t13033.scala | 69 +++++++++++++++++++ 4 files changed, 105 insertions(+), 35 deletions(-) create mode 100644 tests/run/t13033.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index c1939d6f8fa6..5dec60281847 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -505,6 +505,9 @@ class Definitions { @tu lazy val ScalaRuntime_toArray: Symbol = ScalaRuntimeModule.requiredMethod(nme.toArray) @tu lazy val ScalaRuntime_toObjectArray: Symbol = ScalaRuntimeModule.requiredMethod(nme.toObjectArray) + @tu lazy val MurmurHash3Module: Symbol = requiredModule("scala.util.hashing.MurmurHash3") + @tu lazy val MurmurHash3_productHash = MurmurHash3Module.info.member(termName("productHash")).suchThat(_.info.firstParamTypes.size == 3).symbol + @tu lazy val BoxesRunTimeModule: Symbol = requiredModule("scala.runtime.BoxesRunTime") @tu lazy val BoxesRunTimeModule_externalEquals: Symbol = BoxesRunTimeModule.info.decl(nme.equals_).suchThat(toDenot(_).info.firstParamTypes.size == 2).symbol @tu lazy val ScalaStaticsModule: Symbol = requiredModule("scala.runtime.Statics") diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 926a19224e79..73eb69be92a7 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -15,6 +15,7 @@ import util.Property import util.Spans.Span import config.Printers.derive import NullOpsDecorator.* +import scala.runtime.Statics object SyntheticMembers { @@ -101,6 +102,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { val isSimpleEnumValue = isEnumValue && !clazz.owner.isAllOf(EnumCase) val isJavaEnumValue = isEnumValue && clazz.derivesFrom(defn.JavaEnumClass) val isNonJavaEnumValue = isEnumValue && !isJavaEnumValue + val ownName = clazz.name.stripModuleClassSuffix.toString val symbolsToSynthesize: List[Symbol] = if clazz.is(Case) then @@ -124,8 +126,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { def forwardToRuntime(vrefs: List[Tree]): Tree = ref(defn.runtimeMethodRef("_" + sym.name.toString)).appliedToTermArgs(This(clazz) :: vrefs) - def ownName: Tree = - Literal(Constant(clazz.name.stripModuleClassSuffix.toString)) + def ownNameLit: Tree = Literal(Constant(ownName)) def nameRef: Tree = if isJavaEnumValue then @@ -152,7 +153,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { Literal(Constant(candidate.get)) def toStringBody(vrefss: List[List[Tree]]): Tree = - if (clazz.is(ModuleClass)) ownName + if (clazz.is(ModuleClass)) ownNameLit else if (isNonJavaEnumValue) identifierRef else forwardToRuntime(vrefss.head) @@ -165,7 +166,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { case nme.ordinal => ordinalRef case nme.productArity => Literal(Constant(accessors.length)) case nme.productPrefix if isEnumValue => nameRef - case nme.productPrefix => ownName + case nme.productPrefix => ownNameLit case nme.productElement => if ctx.settings.YcompileScala2Library.value then productElementBodyForScala2Compat(accessors.length, vrefss.head.head) else productElementBody(accessors.length, vrefss.head.head) @@ -335,39 +336,36 @@ class SyntheticMembers(thisPhase: DenotTransformer) { ref(accessors.head).select(nme.hashCode_).ensureApplied } - /** The class - * - * ``` - * case object C - * ``` - * - * gets the `hashCode` method: - * - * ``` - * def hashCode: Int = "C".hashCode // constant folded - * ``` - * - * The class - * - * ``` - * case class C(x: T, y: U) - * ``` - * - * if none of `T` or `U` are primitive types, gets the `hashCode` method: - * - * ``` - * def hashCode: Int = ScalaRunTime._hashCode(this) - * ``` - * - * else if either `T` or `U` are primitive, gets the `hashCode` method implemented by [[caseHashCodeBody]] + /** + * A `case object C` or a `case class C()` without parameters gets the `hashCode` method + * ``` + * def hashCode: Int = "C".hashCode // constant folded + * ``` + * + * Otherwise, if none of the parameters are primitive types: + * ``` + * def hashCode: Int = MurmurHash3.productHash( + * this, + * Statics.mix(0xcafebabe, "C".hashCode), // constant folded + * ignorePrefix = true) + * ``` + * + * The implementation used to invoke `ScalaRunTime._hashCode`, but that implementation mixes in the result + * of `productPrefix`, which causes scala/bug#13033. By setting `ignorePrefix = true` and mixing in the case + * name into the seed, the bug can be fixed and the generated code works with the unchanged Scala library. + * + * For case classes with primitive paramters, see [[caseHashCodeBody]]. */ def chooseHashcode(using Context) = - if (clazz.is(ModuleClass)) - Literal(Constant(clazz.name.stripModuleClassSuffix.toString.hashCode)) + if (accessors.isEmpty) Literal(Constant(ownName.hashCode)) else if (accessors.exists(_.info.finalResultType.classSymbol.isPrimitiveValueClass)) caseHashCodeBody else - ref(defn.ScalaRuntime__hashCode).appliedTo(This(clazz)) + ref(defn.MurmurHash3Module).select(defn.MurmurHash3_productHash).appliedTo( + This(clazz), + Literal(Constant(Statics.mix(0xcafebabe, ownName.hashCode))), + Literal(Constant(true)) + ) /** The class * @@ -380,7 +378,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { * ``` * def hashCode: Int = { * var acc: Int = 0xcafebabe - * acc = Statics.mix(acc, this.productPrefix.hashCode()); + * acc = Statics.mix(acc, "C".hashCode); * acc = Statics.mix(acc, x); * acc = Statics.mix(acc, Statics.this.anyHash(y)); * Statics.finalizeHash(acc, 2) @@ -391,7 +389,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { val acc = newSymbol(ctx.owner, nme.acc, Mutable | Synthetic, defn.IntType, coord = ctx.owner.span) val accDef = ValDef(acc, Literal(Constant(0xcafebabe))) val mixPrefix = Assign(ref(acc), - ref(defn.staticsMethod("mix")).appliedTo(ref(acc), This(clazz).select(defn.Product_productPrefix).select(defn.Any_hashCode).appliedToNone)) + ref(defn.staticsMethod("mix")).appliedTo(ref(acc), Literal(Constant(ownName.hashCode)))) val mixes = for (accessor <- accessors) yield Assign(ref(acc), ref(defn.staticsMethod("mix")).appliedTo(ref(acc), hashImpl(accessor))) val finish = ref(defn.staticsMethod("finalizeHash")).appliedTo(ref(acc), Literal(Constant(accessors.size))) diff --git a/tests/run-macros/tasty-extractors-2.check b/tests/run-macros/tasty-extractors-2.check index 15d844670b7a..0a9dd1e94eab 100644 --- a/tests/run-macros/tasty-extractors-2.check +++ b/tests/run-macros/tasty-extractors-2.check @@ -49,7 +49,7 @@ TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit") Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil)), None, List(DefDef("a", Nil, Inferred(), Some(Literal(IntConstant(0))))))), Literal(UnitConstant()))) TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit") -Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), TypeSelect(Select(Ident("_root_"), "scala"), "Product"), TypeSelect(Select(Ident("_root_"), "scala"), "Serializable")), None, List(DefDef("hashCode", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_hashCode"), List(This(Some("Foo")))))), DefDef("equals", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(Apply(Select(This(Some("Foo")), "eq"), List(TypeApply(Select(Ident("x$0"), "$asInstanceOf$"), List(Inferred())))), "||"), List(Match(Ident("x$0"), List(CaseDef(Bind("x$0", Typed(Wildcard(), Inferred())), None, Apply(Select(Literal(BooleanConstant(true)), "&&"), List(Apply(Select(Ident("x$0"), "canEqual"), List(This(Some("Foo"))))))), CaseDef(Wildcard(), None, Literal(BooleanConstant(false))))))))), DefDef("toString", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_toString"), List(This(Some("Foo")))))), DefDef("canEqual", List(TermParamClause(List(ValDef("that", Inferred(), None)))), Inferred(), Some(TypeApply(Select(Ident("that"), "isInstanceOf"), List(Inferred())))), DefDef("productArity", Nil, Inferred(), Some(Literal(IntConstant(0)))), DefDef("productPrefix", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), DefDef("productElement", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("copy", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))))), ValDef("Foo", TypeIdent("Foo$"), Some(Apply(Select(New(TypeIdent("Foo$")), ""), Nil))), ClassDef("Foo$", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), Inferred()), Some(ValDef("_", Singleton(Ident("Foo")), None)), List(DefDef("apply", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))), DefDef("unapply", List(TermParamClause(List(ValDef("x$1", Inferred(), None)))), Singleton(Literal(BooleanConstant(true))), Some(Literal(BooleanConstant(true)))), DefDef("toString", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), TypeDef("MirroredMonoType", TypeBoundsTree(Inferred(), Inferred())), DefDef("fromProduct", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Block(Nil, Apply(Select(New(Inferred()), ""), Nil))))))), Literal(UnitConstant()))) +Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), TypeSelect(Select(Ident("_root_"), "scala"), "Product"), TypeSelect(Select(Ident("_root_"), "scala"), "Serializable")), None, List(DefDef("hashCode", List(TermParamClause(Nil)), Inferred(), Some(Literal(IntConstant(70822)))), DefDef("equals", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(Apply(Select(This(Some("Foo")), "eq"), List(TypeApply(Select(Ident("x$0"), "$asInstanceOf$"), List(Inferred())))), "||"), List(Match(Ident("x$0"), List(CaseDef(Bind("x$0", Typed(Wildcard(), Inferred())), None, Apply(Select(Literal(BooleanConstant(true)), "&&"), List(Apply(Select(Ident("x$0"), "canEqual"), List(This(Some("Foo"))))))), CaseDef(Wildcard(), None, Literal(BooleanConstant(false))))))))), DefDef("toString", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_toString"), List(This(Some("Foo")))))), DefDef("canEqual", List(TermParamClause(List(ValDef("that", Inferred(), None)))), Inferred(), Some(TypeApply(Select(Ident("that"), "isInstanceOf"), List(Inferred())))), DefDef("productArity", Nil, Inferred(), Some(Literal(IntConstant(0)))), DefDef("productPrefix", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), DefDef("productElement", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("copy", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))))), ValDef("Foo", TypeIdent("Foo$"), Some(Apply(Select(New(TypeIdent("Foo$")), ""), Nil))), ClassDef("Foo$", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), Inferred()), Some(ValDef("_", Singleton(Ident("Foo")), None)), List(DefDef("apply", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))), DefDef("unapply", List(TermParamClause(List(ValDef("x$1", Inferred(), None)))), Singleton(Literal(BooleanConstant(true))), Some(Literal(BooleanConstant(true)))), DefDef("toString", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), TypeDef("MirroredMonoType", TypeBoundsTree(Inferred(), Inferred())), DefDef("fromProduct", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Block(Nil, Apply(Select(New(Inferred()), ""), Nil))))))), Literal(UnitConstant()))) TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit") Inlined(None, Nil, Block(List(ClassDef("Foo1", DefDef("", List(TermParamClause(List(ValDef("a", TypeIdent("Int"), None)))), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil)), None, List(ValDef("a", Inferred(), None)))), Literal(UnitConstant()))) diff --git a/tests/run/t13033.scala b/tests/run/t13033.scala new file mode 100644 index 000000000000..43f227990011 --- /dev/null +++ b/tests/run/t13033.scala @@ -0,0 +1,69 @@ +// This method will be in the 2.13.17 standard library. Until this test declares a copy of it. +// import scala.util.hashing.MurmurHash3.caseClassHash +def caseClassHash(x: Product, caseClassName: String = null): Int = + import scala.runtime.Statics._ + val arr = x.productArity + val aye = (if (caseClassName != null) caseClassName else x.productPrefix).hashCode + if (arr == 0) aye + else { + var h = 0xcafebabe + h = mix(h, aye) + var i = 0 + while (i < arr) { + h = mix(h, x.productElement(i).##) + i += 1 + } + finalizeHash(h, arr) + } + + +case class C1(a: Int) +class C2(a: Int) extends C1(a) { override def productPrefix = "C2" } +class C3(a: Int) extends C1(a) { override def productPrefix = "C3" } +case class C4(a: Int) { override def productPrefix = "Sea4" } +case class C5() +case object C6 +case object C6b { override def productPrefix = "Sea6b" } +case class C7(s: String) // hashCode forwards to ScalaRunTime._hashCode if there are no primitives +class C8(s: String) extends C7(s) { override def productPrefix = "C8" } + +case class VCC(x: Int) extends AnyVal + +object Test extends App { + val c1 = C1(1) + val c2 = new C2(1) + val c3 = new C3(1) + assert(c1 == c2) + assert(c2 == c1) + assert(c2 == c3) + assert(c1.hashCode == c2.hashCode) + assert(c2.hashCode == c3.hashCode) + + assert(c1.hashCode == caseClassHash(c1)) + // `caseClassHash` mixes in the `productPrefix.hashCode`, while `hashCode` mixes in the case class name statically + assert(c2.hashCode != caseClassHash(c2)) + assert(c2.hashCode == caseClassHash(c2, c1.productPrefix)) + + val c4 = C4(1) + assert(c4.hashCode != caseClassHash(c4)) + assert(c4.hashCode == caseClassHash(c4, "C4")) + + assert((1, 2).hashCode == caseClassHash(1 -> 2)) + assert(("", "").hashCode == caseClassHash("" -> "")) + + assert(C5().hashCode == caseClassHash(C5())) + assert(C6.hashCode == caseClassHash(C6)) + assert(C6b.hashCode == caseClassHash(C6b, "C6b")) + + val c7 = C7("hi") + val c8 = new C8("hi") + assert(c7.hashCode == caseClassHash(c7)) + assert(c7 == c8) + assert(c7.hashCode == c8.hashCode) + assert(c8.hashCode != caseClassHash(c8)) + assert(c8.hashCode == caseClassHash(c8, "C7")) + + + assert(VCC(1).canEqual(VCC(1))) + assert(!VCC(1).canEqual(1)) +}