Skip to content

Equality cleanup: #678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions jvm/src/test/scala/scala/xml/XMLTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package scala.xml

import org.junit.{Test => UnitTest}
import org.junit.Assert.{assertEquals, assertFalse, assertNull, assertThrows, assertTrue}
import java.io.StringWriter
import java.io.ByteArrayOutputStream
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStreamReader, IOException, ObjectInputStream,
ObjectOutputStream, OutputStreamWriter, PrintStream, StringWriter}
import java.net.URL
import scala.xml.dtd.{DocType, PublicID}
import scala.xml.parsing.ConstructingParser
Expand Down Expand Up @@ -177,26 +177,26 @@ class XMLTestJVM {
</entry>""", f("a,b,c").toString)

object Serialize {
@throws(classOf[java.io.IOException])
@throws(classOf[IOException])
def write[A](o: A): Array[Byte] = {
val ba: ByteArrayOutputStream = new ByteArrayOutputStream(512)
val out: java.io.ObjectOutputStream = new java.io.ObjectOutputStream(ba)
val out: ObjectOutputStream = new ObjectOutputStream(ba)
out.writeObject(o)
out.close()
ba.toByteArray
}
@throws(classOf[java.io.IOException])
@throws(classOf[IOException])
@throws(classOf[ClassNotFoundException])
def read[A](buffer: Array[Byte]): A = {
val in: java.io.ObjectInputStream =
new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(buffer))
val in: ObjectInputStream =
new ObjectInputStream(new ByteArrayInputStream(buffer))
in.readObject().asInstanceOf[A]
}
def check[A, B](x: A, y: B): Unit = {
// println("x = " + x)
// println("y = " + y)
// println("x equals y: " + (x equals y) + ", y equals x: " + (y equals x))
assertTrue(x.equals(y) && y.equals(x))
// println("x == y: " + (x == y) + ", y == x: " + (y == x))
assertTrue(x == y && y == x)
// println()
}
}
Expand Down Expand Up @@ -296,14 +296,14 @@ class XMLTestJVM {
// scala.xml.XML.save("foo.xml", xml)
// scala.xml.XML.loadFile("foo.xml").toString

val outputStream: java.io.ByteArrayOutputStream = new java.io.ByteArrayOutputStream
val streamWriter: java.io.OutputStreamWriter = new java.io.OutputStreamWriter(outputStream, "UTF-8")
val outputStream: ByteArrayOutputStream = new ByteArrayOutputStream
val streamWriter: OutputStreamWriter = new OutputStreamWriter(outputStream, "UTF-8")

XML.write(streamWriter, xml, XML.encoding, xmlDecl = false, null)
streamWriter.flush()

val inputStream: java.io.ByteArrayInputStream = new java.io.ByteArrayInputStream(outputStream.toByteArray)
val streamReader: java.io.InputStreamReader = new java.io.InputStreamReader(inputStream, XML.encoding)
val inputStream: ByteArrayInputStream = new ByteArrayInputStream(outputStream.toByteArray)
val streamReader: InputStreamReader = new InputStreamReader(inputStream, XML.encoding)

assertEquals(xml.toString, XML.load(streamReader).toString)
}
Expand Down Expand Up @@ -492,8 +492,6 @@ class XMLTestJVM {

@UnitTest
def dontLoop(): Unit = {
import java.io.{ Console => _, _ }

val xml: String = "<!DOCTYPE xmeml SYSTEM 'uri'> <xmeml> <sequence> </sequence> </xmeml> "
val sink: PrintStream = new PrintStream(new ByteArrayOutputStream())
Console.withOut(sink) {
Expand Down Expand Up @@ -1026,7 +1024,7 @@ class XMLTestJVM {

def toSource(s: String): scala.io.Source = new scala.io.Source {
override val iter: Iterator[Char] = s.iterator
override def reportError(pos: Int, msg: String, out: java.io.PrintStream = Console.err): Unit = ()
override def reportError(pos: Int, msg: String, out: PrintStream = Console.err): Unit = ()
}

@UnitTest
Expand Down
6 changes: 2 additions & 4 deletions shared/src/main/scala/scala/xml/Attribute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,13 @@ trait Attribute extends MetaData {
}

/** Returns an iterator on attributes */
override def iterator: Iterator[MetaData] = {
override def iterator: Iterator[MetaData] =
if (value == null) next.iterator
else Iterator.single(this) ++ next.iterator
}

override def size: Int = {
override def size: Int =
if (value == null) next.size
else 1 + next.size
}

/**
* Appends string representation of only this attribute to stringbuffer.
Expand Down
7 changes: 3 additions & 4 deletions shared/src/main/scala/scala/xml/Comment.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ case class Comment(commentText: String) extends SpecialNode {
final override def doCollectNamespaces: Boolean = false
final override def doTransform: Boolean = false

if (commentText.contains("--")) {
if (commentText.contains("--"))
throw new IllegalArgumentException(s"""text contains "--"""")
}
if (commentText.nonEmpty && commentText.charAt(commentText.length - 1) == '-') {

if (commentText.nonEmpty && commentText.charAt(commentText.length - 1) == '-')
throw new IllegalArgumentException("The final character of a XML comment may not be '-'. See https://www.w3.org/TR/xml11//#IDA5CES")
}

/**
* Appends &quot;<!-- text -->&quot; to this string buffer.
Expand Down
13 changes: 4 additions & 9 deletions shared/src/main/scala/scala/xml/Equality.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,13 @@ object Equality {
case x: NodeSeq if x.length == 1 => x2 == x(0)
case _ => false
}
def compareBlithely(x1: AnyRef, x2: AnyRef): Boolean = {
if (x1 == null || x2 == null)
return x1.eq(x2)

x2 match {
def compareBlithely(x1: AnyRef, x2: AnyRef): Boolean =
if (x1 == null || x2 == null) x1 == null && x2 == null else x2 match {
case s: String => compareBlithely(x1, s)
case n: Node => compareBlithely(x1, n)
case _ => false
}
}
}
import Equality._

trait Equality extends scala.Equals {
protected def basisForHashCode: Seq[Any]
Expand Down Expand Up @@ -109,10 +104,10 @@ trait Equality extends scala.Equals {
private def doComparison(other: Any, blithe: Boolean): Boolean = {
val strictlyEqual: Boolean = other match {
case x: AnyRef if this.eq(x) => true
case x: Equality => (x canEqual this) && (this strict_== x)
case x: Equality => x.canEqual(this) && this.strict_==(x)
case _ => false
}

strictlyEqual || (blithe && compareBlithely(this, asRef(other)))
strictlyEqual || (blithe && Equality.compareBlithely(this, Equality.asRef(other)))
}
}
21 changes: 11 additions & 10 deletions shared/src/main/scala/scala/xml/MetaData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,27 @@ object MetaData {
*/
@tailrec
def concatenate(attribs: MetaData, new_tail: MetaData): MetaData =
if (attribs.eq(Null)) new_tail
if (attribs.isNull) new_tail
else concatenate(attribs.next, attribs.copy(new_tail))

/**
* returns normalized MetaData, with all duplicates removed and namespace prefixes resolved to
* namespace URIs via the given scope.
*/
def normalize(attribs: MetaData, scope: NamespaceBinding): MetaData = {
def iterate(md: MetaData, normalized_attribs: MetaData, set: Set[String]): MetaData = {
if (md.eq(Null)) {
def iterate(md: MetaData, normalized_attribs: MetaData, set: Set[String]): MetaData =
if (md.isNull) {
normalized_attribs
} else if (md.value.eq(null)) {
} else if (md.value == null)
iterate(md.next, normalized_attribs, set)
} else {
else {
val key: String = getUniversalKey(md, scope)
if (set(key)) {
if (set(key))
iterate(md.next, normalized_attribs, set)
} else {
else
md.copy(iterate(md.next, normalized_attribs, set + key))
}
}
}

iterate(attribs, Null, Set())
}

Expand Down Expand Up @@ -85,7 +84,9 @@ abstract class MetaData
extends AbstractIterable[MetaData]
with Iterable[MetaData]
with Equality
with Serializable {
with Serializable
{
private[xml] def isNull: Boolean = this.eq(Null)

/**
* Updates this MetaData with the MetaData given as argument. All attributes that occur in updates
Expand Down
11 changes: 5 additions & 6 deletions shared/src/main/scala/scala/xml/NamespaceBinding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ case class NamespaceBinding(prefix: String, uri: String, parent: NamespaceBindin
if (prefix == "")
throw new IllegalArgumentException("zero length prefix not allowed")

def getURI(_prefix: String): String =
if (prefix == _prefix) uri else parent.getURI(_prefix)
def getURI(prefix: String): String =
if (this.prefix == prefix) uri else parent.getURI(prefix)

/**
* Returns some prefix that is mapped to the URI.
Expand All @@ -39,8 +39,8 @@ case class NamespaceBinding(prefix: String, uri: String, parent: NamespaceBindin
* @return the prefix that is mapped to the input URI, or null
* if no prefix is mapped to the URI.
*/
def getPrefix(_uri: String): String =
if (_uri == uri) prefix else parent.getPrefix(_uri)
def getPrefix(uri: String): String =
if (uri == this.uri) prefix else parent.getPrefix(uri)

override def toString: String = Utility.sbToString(buildString(_, TopScope))

Expand Down Expand Up @@ -72,9 +72,8 @@ case class NamespaceBinding(prefix: String, uri: String, parent: NamespaceBindin

def buildString(stop: NamespaceBinding): String = Utility.sbToString(buildString(_, stop))

def buildString(sb: StringBuilder, stop: NamespaceBinding): Unit = {
def buildString(sb: StringBuilder, stop: NamespaceBinding): Unit =
shadowRedefined(stop).doBuildString(sb, stop)
}

private def doBuildString(sb: StringBuilder, stop: NamespaceBinding): Unit = {
if (List(null, stop, TopScope).contains(this)) return
Expand Down
4 changes: 2 additions & 2 deletions shared/src/main/scala/scala/xml/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ abstract class Node extends NodeSeq {
* @return the namespace if `scope != null` and prefix was
* found, else `null`
*/
def getNamespace(pre: String): String = if (scope.eq(null)) null else scope.getURI(pre)
def getNamespace(pre: String): String = if (scope == null) null else scope.getURI(pre)

/**
* Convenience method, looks up an unprefixed attribute in attributes of this node.
Expand Down Expand Up @@ -125,7 +125,7 @@ abstract class Node extends NodeSeq {
/**
* Children which do not stringify to "" (needed for equality)
*/
def nonEmptyChildren: Seq[Node] = child.filterNot(_.toString == "")
def nonEmptyChildren: Seq[Node] = child.filterNot(_.toString.isEmpty)

/**
* Descendant axis (all descendants of this node, not including node itself)
Expand Down
2 changes: 1 addition & 1 deletion shared/src/main/scala/scala/xml/NodeSeq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S
val i: Int = that.indexOf('}')
if (i == -1) fail
val (uri: String, key: String) = (that.substring(2, i), that.substring(i + 1, that.length))
if (uri == "" || key == "") fail
if (uri.isEmpty || key.isEmpty) fail
else y.attribute(uri, key)
} else y.attribute(that.drop(1))

Expand Down
13 changes: 7 additions & 6 deletions shared/src/main/scala/scala/xml/PrefixedAttribute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ class PrefixedAttribute(
override val pre: String,
override val key: String,
override val value: Seq[Node],
val next1: MetaData)
extends Attribute {
override val next: MetaData = if (value.ne(null)) next1 else next1.remove(key)
val next1: MetaData
)
extends Attribute
{
override val next: MetaData = if (value != null) next1 else next1.remove(key)

/** same as this(pre, key, Text(value), next), or no attribute if value is null */
def this(pre: String, key: String, value: String, next: MetaData) =
this(pre, key, if (value.ne(null)) Text(value) else null: NodeSeq, next)
this(pre, key, if (value != null) Text(value) else null: NodeSeq, next)

/** same as this(pre, key, value.get, next), or no attribute if value is None */
def this(pre: String, key: String, value: Option[Seq[Node]], next: MetaData) =
Expand All @@ -56,12 +58,11 @@ class PrefixedAttribute(
/**
* gets attribute value of qualified (prefixed) attribute with given key
*/
override def apply(namespace: String, scope: NamespaceBinding, key: String): Seq[Node] = {
override def apply(namespace: String, scope: NamespaceBinding, key: String): Seq[Node] =
if (key == this.key && scope.getURI(pre) == namespace)
value
else
next(namespace, scope, key)
}
}

object PrefixedAttribute {
Expand Down
13 changes: 5 additions & 8 deletions shared/src/main/scala/scala/xml/PrettyPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class PrettyPrinter(width: Int, step: Int, minimizeEmpty: Boolean) {

protected def traverse(node: Node, pscope: NamespaceBinding, ind: Int): Unit = node match {

case Text(s) if s.trim == "" =>
case Text(s) if s.trim.isEmpty =>

case _: Atom[_] | _: Comment | _: EntityRef | _: ProcInstr =>
makeBox(ind, node.toString.trim)
Expand All @@ -163,18 +163,17 @@ class PrettyPrinter(width: Int, step: Int, minimizeEmpty: Boolean) {
if (doPreserve(node)) sb.toString
else TextBuffer.fromString(sb.toString).toText(0).data
}
if (childrenAreLeaves(node) && fits(test)) {
if (childrenAreLeaves(node) && fits(test))
makeBox(ind, test)
} else {
else {
val ((stg: String, len2: Int), etg: String) =
if (node.child.isEmpty && minimizeEmpty) {
// force the tag to be self-closing
val firstAttribute: Int = test.indexOf(' ')
val firstBreak: Int = if (firstAttribute != -1) firstAttribute else test.lastIndexOf('/')
((test, firstBreak), "")
} else {
} else
(startTag(node, pscope), endTag(node))
}

if (stg.length < width - cur) { // start tag fits
makeBox(ind, stg)
Expand Down Expand Up @@ -221,9 +220,7 @@ class PrettyPrinter(width: Int, step: Int, minimizeEmpty: Boolean) {
* @param n the node to be serialized
* @param sb the stringbuffer to append to
*/
def format(n: Node, sb: StringBuilder): Unit = { // entry point
format(n, TopScope, sb)
}
def format(n: Node, sb: StringBuilder): Unit = format(n, TopScope, sb) // entry point

def format(n: Node, pscope: NamespaceBinding, sb: StringBuilder): Unit = { // entry point
var lastwasbreak: Boolean = false
Expand Down
6 changes: 2 additions & 4 deletions shared/src/main/scala/scala/xml/ProcInstr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ case class ProcInstr(target: String, proctext: String) extends SpecialNode {
* appends &quot;&lt;?&quot; target (&quot; &quot;+text)?+&quot;?&gt;&quot;
* to this stringbuffer.
*/
override def buildString(sb: StringBuilder): StringBuilder = {
val textStr: String = if (proctext == "") "" else s" $proctext"
sb.append(s"<?$target$textStr?>")
}
override def buildString(sb: StringBuilder): StringBuilder =
sb.append(s"<?$target${if (proctext.isEmpty) "" else s" $proctext"}?>")
}
6 changes: 2 additions & 4 deletions shared/src/main/scala/scala/xml/TopScope.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ package xml
*/
object TopScope extends NamespaceBinding(null, null, null) {

import XML.{xml, namespace}

override def getURI(prefix1: String): String =
if (prefix1 == xml) namespace else null
if (prefix1 == XML.xml) XML.namespace else null

override def getPrefix(uri1: String): String =
if (uri1 == namespace) xml else null
if (uri1 == XML.namespace) XML.xml else null

override def toString: String = ""

Expand Down
10 changes: 6 additions & 4 deletions shared/src/main/scala/scala/xml/UnprefixedAttribute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@ import scala.collection.Seq
class UnprefixedAttribute(
override val key: String,
override val value: Seq[Node],
next1: MetaData)
extends Attribute {
next1: MetaData
)
extends Attribute
{
final override val pre: scala.Null = null
override val next: MetaData = if (value.ne(null)) next1 else next1.remove(key)
override val next: MetaData = if (value != null) next1 else next1.remove(key)

/** same as this(key, Text(value), next), or no attribute if value is null */
def this(key: String, value: String, next: MetaData) =
this(key, if (value.ne(null)) Text(value) else null: NodeSeq, next)
this(key, if (value != null) Text(value) else null: NodeSeq, next)

/** same as this(key, value.get, next), or no attribute if value is None */
def this(key: String, value: Option[Seq[Node]], next: MetaData) =
Expand Down
Loading