From ad88ddd7c84dbc48e2d5a26421d9319179dacc3c Mon Sep 17 00:00:00 2001 From: Austin Marshall Date: Thu, 9 Apr 2015 19:08:42 -0500 Subject: [PATCH] Cleanup multi encoder unit tests --- tests/unit/nupic/encoders/multi_test.py | 61 ++++++++++++------------- 1 file changed, 29 insertions(+), 32 deletions(-) mode change 100644 => 100755 tests/unit/nupic/encoders/multi_test.py diff --git a/tests/unit/nupic/encoders/multi_test.py b/tests/unit/nupic/encoders/multi_test.py old mode 100644 new mode 100755 index deac4faa93..300e30049a --- a/tests/unit/nupic/encoders/multi_test.py +++ b/tests/unit/nupic/encoders/multi_test.py @@ -27,57 +27,55 @@ import unittest2 as unittest from nupic.encoders.multi import MultiEncoder -from nupic.encoders import * +from nupic.encoders import ScalarEncoder, SDRCategoryEncoder from nupic.data.dictutils import DictObj from nupic.encoders.multi_capnp import MultiEncoderProto -######################################################################### class MultiEncoderTest(unittest.TestCase): - '''Unit tests for MultiEncoder class''' + """Unit tests for MultiEncoder class""" -########################################################################## def testMultiEncoder(self): """Testing MultiEncoder...""" e = MultiEncoder() # should be 7 bits wide - # use of forced=True is not recommended, but here for readibility, see scalar.py - e.addEncoder("dow", ScalarEncoder(w=3, resolution=1, minval=1, maxval=8, - periodic=True, name="day of week", forced=True)) + # use of forced=True is not recommended, but here for readibility, see + # scalar.py + e.addEncoder("dow", + ScalarEncoder(w=3, resolution=1, minval=1, maxval=8, + periodic=True, name="day of week", forced=True)) # sould be 14 bits wide - e.addEncoder("myval", ScalarEncoder(w=5, resolution=1, minval=1, maxval=10, - periodic=False, name="aux", forced=True)) + e.addEncoder("myval", + ScalarEncoder(w=5, resolution=1, minval=1, maxval=10, + periodic=False, name="aux", forced=True)) self.assertEqual(e.getWidth(), 21) self.assertEqual(e.getDescription(), [("day of week", 0), ("aux", 7)]) d = DictObj(dow=3, myval=10) - expected=numpy.array([0,1,1,1,0,0,0] + [0,0,0,0,0,0,0,0,0,1,1,1,1,1], dtype='uint8') + expected=numpy.array([0,1,1,1,0,0,0] + [0,0,0,0,0,0,0,0,0,1,1,1,1,1], + dtype="uint8") output = e.encode(d) - assert(expected == output).all() - - - e.pprintHeader() - e.pprint(output) + self.assertTrue(numpy.array_equal(expected, output)) # Check decoding decoded = e.decode(output) - #print decoded self.assertEqual(len(decoded), 2) - (ranges, desc) = decoded[0]['aux'] - self.assertTrue(len(ranges) == 1 and numpy.array_equal(ranges[0], [10, 10])) - (ranges, desc) = decoded[0]['day of week'] + (ranges, _) = decoded[0]["aux"] + self.assertEqual(len(ranges), 1) + self.assertTrue(numpy.array_equal(ranges[0], [10, 10])) + (ranges, _) = decoded[0]["day of week"] self.assertTrue(len(ranges) == 1 and numpy.array_equal(ranges[0], [3, 3])) - print "decodedToStr=>", e.decodedToStr(decoded) - e.addEncoder("myCat", SDRCategoryEncoder(n=7, w=3, - categoryList=["run", "pass","kick"], forced=True)) + e.addEncoder("myCat", + SDRCategoryEncoder(n=7, w=3, + categoryList=["run", "pass","kick"], + forced=True)) - print "\nTesting mixed multi-encoder" d = DictObj(dow=4, myval=6, myCat="pass") output = e.encode(d) topDownOut = e.topDownCompute(output) @@ -91,16 +89,16 @@ def testMultiEncoder(self): def testReadWrite(self): original = MultiEncoder() - original.addEncoder("dow", ScalarEncoder(w=3, resolution=1, minval=1, - maxval=8, periodic=True, name="day of week", - forced=True)) - - original.addEncoder("myval", ScalarEncoder(w=5, resolution=1, minval=1, - maxval=10, periodic=False, name="aux", forced=True)) + original.addEncoder("dow", + ScalarEncoder(w=3, resolution=1, minval=1, maxval=8, + periodic=True, name="day of week", + forced=True)) + original.addEncoder("myval", + ScalarEncoder(w=5, resolution=1, minval=1, maxval=10, + periodic=False, name="aux", forced=True)) originalValue = DictObj(dow=3, myval=10) output = original.encode(originalValue) - proto1 = MultiEncoderProto.new_message() original.write(proto1) @@ -129,6 +127,5 @@ def testReadWrite(self): -########################################### -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()