Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions python/pyspark/mllib/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import sys
import array
import struct

if sys.version >= '3':
basestring = str
Expand Down Expand Up @@ -122,6 +123,15 @@ def _format_float_list(l):
return [_format_float(x) for x in l]


def _double_to_long_bits(value):
    """
    Reinterpret the bits of a double as an unsigned 64-bit integer.

    Every NaN input maps to the single pattern 0x7ff8000000000000, so
    all NaN values produce the same result.
    """
    if value != value:
        # Only NaN compares unequal to itself; collapse every NaN
        # payload to the canonical non-signaling NaN pattern.
        return 0x7ff8000000000000
    # Pack as a C double, then read the same 8 bytes back as an
    # unsigned long long.
    raw = struct.pack('d', value)
    (bits,) = struct.unpack('Q', raw)
    return bits


class VectorUDT(UserDefinedType):
"""
SQL user-defined type (UDT) for Vector.
Expand Down Expand Up @@ -409,6 +419,34 @@ def __eq__(self, other):
def __ne__(self, other):
    """Return the logical negation of ``self == other``."""
    is_equal = self == other
    return not is_equal

def __hash__(self):
    """
    Compute a hash code from the vector size and its leading nonzero
    entries.

    Zero entries are skipped and every nonzero value is combined with
    its index before hashing, so the result does not depend on whether
    zeros are stored explicitly (the first doctest compares against an
    equivalent SparseVector). At most 16 nonzero entries are folded in.

    >>> v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
    >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
    >>> hash(v1) == hash(v2)
    True
    >>> v2 = DenseVector([0.0, 1.0, 0.0, 5.5])
    >>> hash(v1) == hash(v2)
    True
    >>> v2 = DenseVector([1.0, 1.0, 0.0, 5.5])
    >>> hash(v1) == hash(v2)
    False
    """
    size = len(self)
    result = 31 + size
    nnz = 0
    i = 0
    while i < size and nnz < 16:
        value = self.array[i]
        if value != 0:
            bits = _double_to_long_bits(value + i)
            # Fold the high and low 32-bit halves together.
            result = 31 * result + (bits ^ (bits >> 32))
            # Count only entries that were actually hashed. Counting
            # zero slots too would make the 16-entry cap depend on how
            # many zeros precede the nonzeros, and the hash would then
            # differ from a sparse encoding of the same values.
            nnz += 1
        i += 1
    return result

def __getattr__(self, item):
return getattr(self.array, item)

Expand Down Expand Up @@ -739,6 +777,33 @@ def __getitem__(self, index):
def __ne__(self, other):
    """Return the logical negation of ``self.__eq__(other)``."""
    is_equal = self.__eq__(other)
    return not is_equal

def __hash__(self):
    """
    Compute a hash code from the vector size and its leading nonzero
    stored entries.

    Stored values equal to zero are skipped, and every nonzero value is
    combined with its index before hashing. At most 16 nonzero entries
    are folded in.

    >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)])
    >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
    >>> hash(v1) == hash(v2)
    True
    >>> v2 = SparseVector(4, [(1, 1.0), (3, 2.5)])
    >>> hash(v1) == hash(v2)
    False
    >>> v2 = SparseVector(4, [(2, 1.0), (3, 5.5)])
    >>> hash(v1) == hash(v2)
    False
    """
    result = 31 + self.size
    num_stored = len(self.values)
    nnz = 0
    i = 0
    while i < num_stored and nnz < 16:
        value = self.values[i]
        if value != 0:
            bits = _double_to_long_bits(value + self.indices[i])
            # Fold the high and low 32-bit halves together.
            result = 31 * result + (bits ^ (bits >> 32))
            # Count only entries that were actually hashed, so that
            # explicitly stored zeros neither change the hash nor use
            # up the 16-entry budget.
            nnz += 1
        i += 1
    return result


class Vectors(object):

Expand Down