Skip to content

Commit

Permalink
Merge pull request numenta#2033 from oxtopus/issue-2032-cleanup-multi…
Browse files Browse the repository at this point in the history
…-unit-tests

Cleanup multi encoder unit tests
  • Loading branch information
rhyolight committed May 6, 2015
2 parents 73770a8 + ad88ddd commit e373be3
Showing 1 changed file with 29 additions and 32 deletions.
61 changes: 29 additions & 32 deletions tests/unit/nupic/encoders/multi_test.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -27,57 +27,55 @@
import unittest2 as unittest

from nupic.encoders.multi import MultiEncoder
from nupic.encoders import *
from nupic.encoders import ScalarEncoder, SDRCategoryEncoder
from nupic.data.dictutils import DictObj

from nupic.encoders.multi_capnp import MultiEncoderProto



#########################################################################
class MultiEncoderTest(unittest.TestCase):
'''Unit tests for MultiEncoder class'''
"""Unit tests for MultiEncoder class"""


##########################################################################
def testMultiEncoder(self):
"""Testing MultiEncoder..."""

e = MultiEncoder()

# should be 7 bits wide
# use of forced=True is not recommended, but here for readibility, see scalar.py
e.addEncoder("dow", ScalarEncoder(w=3, resolution=1, minval=1, maxval=8,
periodic=True, name="day of week", forced=True))
# use of forced=True is not recommended, but here for readibility, see
# scalar.py
e.addEncoder("dow",
ScalarEncoder(w=3, resolution=1, minval=1, maxval=8,
periodic=True, name="day of week", forced=True))
# sould be 14 bits wide
e.addEncoder("myval", ScalarEncoder(w=5, resolution=1, minval=1, maxval=10,
periodic=False, name="aux", forced=True))
e.addEncoder("myval",
ScalarEncoder(w=5, resolution=1, minval=1, maxval=10,
periodic=False, name="aux", forced=True))
self.assertEqual(e.getWidth(), 21)
self.assertEqual(e.getDescription(), [("day of week", 0), ("aux", 7)])

d = DictObj(dow=3, myval=10)
expected=numpy.array([0,1,1,1,0,0,0] + [0,0,0,0,0,0,0,0,0,1,1,1,1,1], dtype='uint8')
expected=numpy.array([0,1,1,1,0,0,0] + [0,0,0,0,0,0,0,0,0,1,1,1,1,1],
dtype="uint8")
output = e.encode(d)
assert(expected == output).all()


e.pprintHeader()
e.pprint(output)
self.assertTrue(numpy.array_equal(expected, output))

# Check decoding
decoded = e.decode(output)
#print decoded
self.assertEqual(len(decoded), 2)
(ranges, desc) = decoded[0]['aux']
self.assertTrue(len(ranges) == 1 and numpy.array_equal(ranges[0], [10, 10]))
(ranges, desc) = decoded[0]['day of week']
(ranges, _) = decoded[0]["aux"]
self.assertEqual(len(ranges), 1)
self.assertTrue(numpy.array_equal(ranges[0], [10, 10]))
(ranges, _) = decoded[0]["day of week"]
self.assertTrue(len(ranges) == 1 and numpy.array_equal(ranges[0], [3, 3]))
print "decodedToStr=>", e.decodedToStr(decoded)

e.addEncoder("myCat", SDRCategoryEncoder(n=7, w=3,
categoryList=["run", "pass","kick"], forced=True))
e.addEncoder("myCat",
SDRCategoryEncoder(n=7, w=3,
categoryList=["run", "pass","kick"],
forced=True))

print "\nTesting mixed multi-encoder"
d = DictObj(dow=4, myval=6, myCat="pass")
output = e.encode(d)
topDownOut = e.topDownCompute(output)
Expand All @@ -91,16 +89,16 @@ def testMultiEncoder(self):

def testReadWrite(self):
original = MultiEncoder()
original.addEncoder("dow", ScalarEncoder(w=3, resolution=1, minval=1,
maxval=8, periodic=True, name="day of week",
forced=True))

original.addEncoder("myval", ScalarEncoder(w=5, resolution=1, minval=1,
maxval=10, periodic=False, name="aux", forced=True))
original.addEncoder("dow",
ScalarEncoder(w=3, resolution=1, minval=1, maxval=8,
periodic=True, name="day of week",
forced=True))
original.addEncoder("myval",
ScalarEncoder(w=5, resolution=1, minval=1, maxval=10,
periodic=False, name="aux", forced=True))
originalValue = DictObj(dow=3, myval=10)
output = original.encode(originalValue)


proto1 = MultiEncoderProto.new_message()
original.write(proto1)

Expand Down Expand Up @@ -129,6 +127,5 @@ def testReadWrite(self):



###########################################
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

0 comments on commit e373be3

Please sign in to comment.