Skip to content

Commit

Permalink
add speical cases for single symbol Huffman codes, see #172
Browse files Browse the repository at this point in the history
  • Loading branch information
ilanschnell committed Apr 20, 2022
1 parent 7fe1a02 commit 7d80bae
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Once you have installed the package, you may want to test it:
.........................................................................
................................................................
----------------------------------------------------------------------
Ran 423 tests in 0.515s
Ran 424 tests in 0.515s
OK
Expand Down
38 changes: 26 additions & 12 deletions bitarray/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,16 +1473,6 @@ def test_simple(self):
self.assertEqual(len(code['as']), 2)
self.assertEqual(len(code[None]), 2)

def test_tiny(self):
code = huffman_code({0: 0})
self.assertEqual(len(code), 1)
self.assertEqual(code, {0: bitarray()})

code = huffman_code({0: 0, 1: 0})
self.assertEqual(len(code), 2)
for i in range(2):
self.assertEqual(len(code[i]), 1)

def test_endianness(self):
freq = {'A': 10, 'B': 2, 'C': 5}
for endian in 'big', 'little':
Expand All @@ -1497,8 +1487,20 @@ def test_wrong_arg(self):
self.assertRaises(TypeError, huffman_code, None)
# cannot compare 'a' with 1
self.assertRaises(TypeError, huffman_code, {'A': 'a', 'B': 1})
# frequency map cannot be empty
self.assertRaises(ValueError, huffman_code, {})

def test_one_symbol(self):
cnt = {'a': 1}
code = huffman_code(cnt)
self.assertEqual(code, {'a': bitarray('0')})
for n in range(4):
msg = n * ['a']
a = bitarray()
a.encode(code, msg)
self.assertEqual(a.to01(), n * '0')
self.assertEqual(a.decode(code), msg)

def check_tree(self, code):
n = len(code)
tree = decodetree(code)
Expand Down Expand Up @@ -1568,12 +1570,24 @@ def test_basic(self):

def test_canonical_huffman_errors(self):
self.assertRaises(TypeError, canonical_huffman, [])
# frequency map cannot be empty
self.assertRaises(ValueError, canonical_huffman, {})
self.assertRaises(TypeError, canonical_huffman)
cnt = huffman_code(Counter('aabc'))
self.assertRaises(TypeError, canonical_huffman, cnt, 'a')
cnt = {'a': 1} # only one symbol
self.assertRaises(ValueError, canonical_huffman, cnt)

def test_one_symbol(self):
cnt = {'a': 1}
chc, count, symbol = canonical_huffman(cnt)
self.assertEqual(chc, {'a': bitarray('0')})
self.assertEqual(count, [0, 1])
self.assertEqual(symbol, ['a'])
for n in range(4):
msg = n * ['a']
a = bitarray()
a.encode(chc, msg)
self.assertEqual(a.to01(), n * '0')
self.assertEqual(list(canonical_decode(a, count, symbol)), msg)

def test_canonical_decode_errors(self):
a = bitarray('1101')
Expand Down
17 changes: 13 additions & 4 deletions bitarray/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,16 +379,20 @@ def huffman_code(__freq_map, endian=None):
bitarrays (with given endianness). Note that the symbols are not limited
to being strings. Symbols may may be any hashable object (such as `None`).
"""

if not isinstance(__freq_map, dict):
raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)
if len(__freq_map) == 0:
raise ValueError("non-empty dict expected")
if endian is None:
endian = get_default_endian()

b0 = bitarray('0', endian)
b1 = bitarray('1', endian)

if len(__freq_map) < 2:
if len(__freq_map) == 0:
raise ValueError("cannot create Huffman code with no symbols")
# technically not a Huffman tree but what one would expect
return {list(__freq_map)[0]: b0}

result = {}

def traverse(nd, prefix=bitarray(0, endian)):
Expand Down Expand Up @@ -416,8 +420,13 @@ def canonical_huffman(__freq_map):
"""
if not isinstance(__freq_map, dict):
raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

if len(__freq_map) < 2:
raise ValueError("at least 2 symbols expected in frequency map")
if len(__freq_map) == 0:
raise ValueError("cannot create Huffman code with no symbols")
# technically not a Huffman tree but what one would expect
sym = list(__freq_map)[0]
return {sym: bitarray('0')}, [0, 1], [sym]

code_length = {} # map symbols to their code length

Expand Down

0 comments on commit 7d80bae

Please sign in to comment.