diff --git a/dawg_python/dawgs.py b/dawg_python/dawgs.py index ae0f5b5..93e2522 100644 --- a/dawg_python/dawgs.py +++ b/dawg_python/dawgs.py @@ -141,6 +141,39 @@ def keys(self, prefix=""): return res + def children(self, prefix=""): + b_prefix = prefix.encode('utf8') + res = [] + + index = self.dct.follow_bytes(b_prefix, self.dct.ROOT) + if index is None: + return res + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide) + if not edge_follower.start(index, b_prefix): + return res + + res.append(edge_follower.get_cur_child()) + while edge_follower.next(): + res.append(edge_follower.get_cur_child()) + + return res + + def iterchildren(self, prefix=""): + b_prefix = prefix.encode('utf8') + + index = self.dct.follow_bytes(b_prefix, self.dct.ROOT) + if index is None: + return + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide) + if not edge_follower.start(index, b_prefix): + return + + yield edge_follower.get_cur_child() + while edge_follower.next(): + yield edge_follower.get_cur_child() + def iterkeys(self, prefix=""): b_prefix = prefix.encode('utf8') index = self.dct.follow_bytes(b_prefix, self.dct.ROOT) @@ -279,15 +312,14 @@ def iterkeys(self, prefix=""): yield u_key def items(self, prefix=""): + index = self.dct.ROOT if not isinstance(prefix, bytes): prefix = prefix.encode('utf8') - res = [] - - index = self.dct.ROOT if prefix: index = self.dct.follow_bytes(prefix, index) if not index: - return res + return [] + res = [] completer = wrapper.Completer(self.dct, self.guide) completer.start(index, prefix) @@ -301,10 +333,9 @@ def items(self, prefix=""): return res def iteritems(self, prefix=""): + index = self.dct.ROOT if not isinstance(prefix, bytes): prefix = prefix.encode('utf8') - - index = self.dct.ROOT if prefix: index = self.dct.follow_bytes(prefix, index) if not index: @@ -315,9 +346,95 @@ def iteritems(self, prefix=""): while completer.next(): key, value = completer.key.split(self._payload_separator) - item = (key.decode('utf8'), a2b_base64(bytes(value))) # bytes() cast is a python 2.6 fix + # bytes() cast is a python 2.6 fix + item = (key.decode('utf8'), a2b_base64(bytes(value))) yield item + def children(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + res = [] + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide, + self._payload_separator) + if not edge_follower.start(index, prefix): + return res + + val = True if self._follow_key(bytes(edge_follower.key)) else False + res.append((edge_follower.decoded_key, val)) + while edge_follower.next(): + val = True if self._follow_key(bytes(edge_follower.key)) else False + res.append((edge_follower.decoded_key, val)) + return res + + def iterchildren(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide, + self._payload_separator) + if not edge_follower.start(index, prefix): + return + + val = True if self._follow_key(bytes(edge_follower.key)) else False + yield (edge_follower.decoded_key, val) + while edge_follower.next(): + val = True if self._follow_key(bytes(edge_follower.key)) else False + yield (edge_follower.decoded_key, val) + + def children_data(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + res = [] + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide, + self._payload_separator) + if not edge_follower.start(index, prefix): + return res + + vals = self.b_get_value(bytes(edge_follower.key)) or [None] + res.extend([(edge_follower.decoded_key, val) for val in vals]) + while edge_follower.next(): + vals = self.b_get_value(bytes(edge_follower.key)) or [None] + res.extend([(edge_follower.decoded_key, val) for val in vals]) + return res + + def iterchildren_data(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide, + self._payload_separator) + if not edge_follower.start(index, prefix): + return + + vals = self.b_get_value(bytes(edge_follower.key)) or [None] + for val in vals: + yield (edge_follower.decoded_key, val) + while edge_follower.next(): + vals = self.b_get_value(bytes(edge_follower.key)) or [None] + for val in vals: + yield (edge_follower.decoded_key, val) def _has_value(self, index): return self.dct.follow_bytes(PAYLOAD_SEPARATOR, index) @@ -368,7 +485,6 @@ def similar_items(self, key, replaces): """ return self._similar_items("", key, self.dct.ROOT, replaces) - def _similar_item_values(self, start_pos, key, index, replace_chars): res = [] end_pos = len(key) @@ -424,15 +540,17 @@ def _value_for_index(self, index): def items(self, prefix=""): res = super(RecordDAWG, self).items(prefix) + print("items data:") + print(res) return [(key, self._struct.unpack(val)) for (key, val) in res] def iteritems(self, prefix=""): res = super(RecordDAWG, self).iteritems(prefix) return ((key, self._struct.unpack(val)) for (key, val) in res) - LOOKUP_ERROR = -1 + class IntDAWG(DAWG): """ Dict-like class based on DAWG. @@ -464,6 +582,80 @@ class IntCompletionDAWG(CompletionDAWG, IntDAWG): Dict-like class based on DAWG. It can store integer values for unicode keys and support key completion. """ + def children(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + res = [] + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide) + if not edge_follower.start(index, prefix): + return res + + res.append((edge_follower.decoded_key, edge_follower.has_value())) + while edge_follower.next(): + res.append((edge_follower.decoded_key, edge_follower.has_value())) + + return res + + def iterchildren(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide) + if not edge_follower.start(index, prefix): + return + + yield (edge_follower.decoded_key, edge_follower.has_value()) + while edge_follower.next(): + yield (edge_follower.decoded_key, edge_follower.has_value()) + + def children_data(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + res = [] + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide) + if not edge_follower.start(index, prefix): + return res + + res.append((edge_follower.decoded_key, edge_follower.value())) + while edge_follower.next(): + res.append((edge_follower.decoded_key, edge_follower.value())) + + return res + + def iterchildren_data(self, prefix=""): + index = self.dct.ROOT + if not isinstance(prefix, bytes): + prefix = prefix.encode('utf8') + if prefix: + index = self.dct.follow_bytes(prefix, index) + if not index: + return + + edge_follower = wrapper.EdgeFollower(self.dct, self.guide) + if not edge_follower.start(index, prefix): + return + + yield (edge_follower.decoded_key, edge_follower.value()) + while edge_follower.next(): + yield (edge_follower.decoded_key, edge_follower.value()) + def items(self, prefix=""): if not isinstance(prefix, bytes): prefix = prefix.encode('utf8') diff --git a/dawg_python/wrapper.py b/dawg_python/wrapper.py index 863faf8..be49230 100644 --- a/dawg_python/wrapper.py +++ b/dawg_python/wrapper.py @@ -17,29 +17,29 @@ def __init__(self): "Root index" def has_value(self, index): - "Checks if a given index is related to the end of a key." + #Checks if a given index is related to the end of a key. return units.has_leaf(self._units[index]) def value(self, index): - "Gets a value from a given index." + #Gets a value from a given index. offset = units.offset(self._units[index]) value_index = (index ^ offset) & units.PRECISION_MASK return units.value(self._units[value_index]) def read(self, fp): - "Reads a dictionary from an input stream." + #Reads a dictionary from an input stream. base_size = struct.unpack(str("=I"), fp.read(4))[0] self._units.fromfile(fp, base_size) def contains(self, key): - "Exact matching." + #Exact matching. index = self.follow_bytes(key, self.ROOT) if index is None: return False return self.has_value(index) def find(self, key): - "Exact matching (returns value)" + #Exact matching (returns value) index = self.follow_bytes(key, self.ROOT) if index is None: return -1 @@ -48,7 +48,7 @@ def find(self, key): return self.value(index) def follow_char(self, label, index): - "Follows a transition" + #Follows a transition offset = units.offset(self._units[index]) next_index = (index ^ offset ^ label) & units.PRECISION_MASK @@ -58,7 +58,7 @@ def follow_char(self, label, index): return next_index def follow_bytes(self, s, index): - "Follows transitions." + #Follows transitions. for ch in s: index = self.follow_char(int_from_byte(ch), index) if index is None: @@ -95,6 +95,99 @@ def size(self): return len(self._units) +class EdgeFollower(object): + def __init__(self, dic=None, guide=None, payload_separator=b'\x01'): + self._payload_separator = ord(payload_separator) + self._dic = dic + self._guide = guide + + def value(self): + "provides list of values at current index" + + if self._dic.has_value(self._cur_index): + return self._dic.value(self._cur_index) + return None + + def has_value(self): + "boolean telling whether or not cur_index has a value" + if self._dic.has_value(self._cur_index): + return True + return False + + def start(self, index, prefix=b""): + """initial setup for the next() action on some prefix. If there's a + child for this prefix, we add that as the one item on the index_stack. + Otherwise, leave the stack empty, so next() fails""" + + self.key = bytearray(prefix) + self.base_key_len = len(self.key) + self._parent_index = index + self._sib_index_stack = [] + if self._guide.size(): + child_label = self._guide.child(index) + if child_label: + # Follows a transition to the first child. + child_index = self._dic.follow_char(child_label, index) + if index is not None: + self._sib_index_stack.append( + (child_index, 0, None, bytearray())) + #skip if the child is \x01 (the divider char) + if child_label == self._payload_separator: + return self.next() + else: + return self._get_next_multibyte( + child_label, child_index, None, bytearray()) + return False + + def _get_next_multibyte(self, child_label, index, lvls=None, + part_key=None): + """given some child_label and its index, goes down the approp num levels + to get the first decodable chr""" + part_key.append(child_label) + if lvls is None: + lvls = levels_to_descend(child_label) + if lvls > 0: + for i in reversed(range(lvls)): + next_child_label = self._guide.child(index) + prev_index = index + index = self._dic.follow_char(next_child_label, index) + self._sib_index_stack.append( + (index, i, prev_index, part_key[:])) + part_key.append(next_child_label) + self.key.extend(part_key) + self.decoded_key = self.key.decode('utf8') + self._cur_index = index + return True + + def next(self): + "Gets the next child (not necessarily a terminal)" + + if not self._sib_index_stack: + return False + sib_index, lvls, parent_index, part_key = self._sib_index_stack.pop() + if not parent_index: + parent_index = self._parent_index + sibling_label = self._guide.sibling(sib_index) + sib_index = self._dic.follow_char(sibling_label, parent_index) + if not sib_index: + return self.next() + if lvls == 0: + lvls = None + self._sib_index_stack.append( + (sib_index, lvls, parent_index, part_key[:])) + if sibling_label == self._payload_separator: + return self.next() + self.key = self.key[:self.base_key_len] + return self._get_next_multibyte(sibling_label, sib_index, lvls, + part_key) + + def get_cur_child(self): + """helper method for getting the decoded key along with whether or not + it is a terminal""" + + return (self.decoded_key, self._dic.has_value(self._cur_index)) + + class Completer(object): def __init__(self, dic=None, guide=None): @@ -102,9 +195,13 @@ def __init__(self, dic=None, guide=None): self._guide = guide def value(self): + "provides list of values at current index" + return self._dic.value(self._last_index) def start(self, index, prefix=b""): + "initial setup for a completer next() action on some prefix" + self.key = bytearray(prefix) if self._guide.size(): @@ -113,7 +210,6 @@ def start(self, index, prefix=b""): else: self._index_stack = [] - def next(self): "Gets the next key" @@ -153,7 +249,6 @@ def next(self): return self._find_terminal(index) - def _follow(self, label, index): next_index = self._dic.follow_char(label, index) if next_index is None: @@ -176,3 +271,17 @@ def _find_terminal(self, index): self._last_index = index return True + + +#the first byte in a utf-8 char determines how many total bytes are in the char. +#the number of bytes = number of leading ones in first byte (i.e. e5 = 225 = +#3 bytes (including the first) +def levels_to_descend(byte_val): + if byte_val < 192: + return 0 + elif byte_val < 224: + return 1 + elif byte_val < 240: + return 2 + else: + return 3 diff --git a/dev_data/small/bytes.dawg b/dev_data/small/bytes.dawg index debaacb..9ca2377 100644 Binary files a/dev_data/small/bytes.dawg and b/dev_data/small/bytes.dawg differ diff --git a/dev_data/small/int_completion_dawg.dawg b/dev_data/small/int_completion_dawg.dawg index a25033e..8d8794c 100644 Binary files a/dev_data/small/int_completion_dawg.dawg and b/dev_data/small/int_completion_dawg.dawg differ diff --git a/dev_data/small/int_dawg.dawg b/dev_data/small/int_dawg.dawg index 2c42a5c..db39e42 100644 Binary files a/dev_data/small/int_dawg.dawg and b/dev_data/small/int_dawg.dawg differ diff --git a/tests/test_dawg.py b/tests/test_dawg.py index 0c74e90..81fdd77 100644 --- a/tests/test_dawg.py +++ b/tests/test_dawg.py @@ -8,6 +8,7 @@ from .utils import data_path + def test_c_dawg_contains(): dawg = pytest.importorskip("dawg") # import dawg bin_dawg = dawg.IntDAWG({'foo': 1, 'bar': 2, 'foobar': 3}) @@ -30,7 +31,8 @@ class TestCompletionDAWG(object): keys = ['f', 'bar', 'foo', 'foobar'] def dawg(self): - return dawg_python.CompletionDAWG().load(data_path('small', 'completion.dawg')) + return dawg_python.CompletionDAWG().load(data_path('small', + 'completion.dawg')) def test_contains(self): d = self.dawg() @@ -46,10 +48,24 @@ def test_keys(self): d = self.dawg() assert d.keys() == sorted(self.keys) + def test_children(self): + d = self.dawg() + assert d.children() == [('b', False), ('f', True)] + assert d.children('b') == [('ba', False)] + assert d.children('fo') == [('foo', True)] + assert d.children('foobar') == [] + def test_iterkeys(self): d = self.dawg() assert list(d.iterkeys()) == d.keys() + def test_iter_children(self): + d = self.dawg() + assert list(d.iterchildren()) == [('b', False), ('f', True)] + assert list(d.iterchildren('b')) == [('ba', False)] + assert list(d.children('fo')) == [('foo', True)] + assert list(d.children('foobar')) == [] + def test_completion(self): d = self.dawg() @@ -77,9 +93,8 @@ def test_prefixes(self): assert d.prefixes("bar") == ["bar"] - class TestIntDAWG(object): - payload = {'foo': 1, 'bar': 5, 'foobar': 3} + payload = {'foo': 1, 'bar': 5, 'foobar': 30} def dawg(self): return dawg_python.IntDAWG().load(data_path('small', 'int_dawg.dawg')) @@ -119,3 +134,27 @@ def test_completion_keys_with_prefix(self): def test_completion_items(self): assert self.dawg().items() == sorted(self.payload.items(), key=lambda r: r[0]) + + def test_completion_children(self): + assert self.dawg().children('ba') == [('bar', True)] + assert self.dawg().children('foob') == [('fooba', False)] + assert self.dawg().children('fooba') == [('foobar', True)] + assert self.dawg().children('foobar') == [] + + def test_completion_iterchildren(self): + assert list(self.dawg().iterchildren('ba')) == [('bar', True)] + assert list(self.dawg().iterchildren('foob')) == [('fooba', False)] + assert list(self.dawg().iterchildren('fooba')) == [('foobar', True)] + assert list(self.dawg().iterchildren('foobar')) == [] + + def test_completion_children_data(self): + assert self.dawg().children_data('ba') == [('bar', 5)] + assert self.dawg().children_data('foob') == [('fooba', None)] + assert self.dawg().children_data('fooba') == [('foobar', 30)] + assert self.dawg().children_data('foobar') == [] + + def test_completion_iterchildren_data(self): + assert list(self.dawg().iterchildren_data('ba')) == [('bar', 5)] + assert list(self.dawg().iterchildren_data('foob')) == [('fooba', None)] + assert list(self.dawg().iterchildren_data('fooba')) == [('foobar', 30)] + assert list(self.dawg().iterchildren_data('foobar')) == [] diff --git a/tests/test_payload_dawg.py b/tests/test_payload_dawg.py index 4f9060d..0fa0825 100644 --- a/tests/test_payload_dawg.py +++ b/tests/test_payload_dawg.py @@ -11,7 +11,10 @@ class TestBytesDAWG(object): ('foo', b'data1'), ('bar', b'data2'), ('foo', b'data3'), - ('foobar', b'data4') + ('foobar', b'data4'), + ('ሀ', b'ethiopic_sign1'), + ('ሮ', b'ethiopic_sign2'), + ('ቄ', b'ethiopic_sign3') ) def dawg(self): @@ -33,6 +36,7 @@ def test_getitem(self): assert d['foo'] == [b'data1', b'data3'] assert d['bar'] == [b'data2'] assert d['foobar'] == [b'data4'] + assert d['\u1244'] == [b'ethiopic_sign3'] def test_getitem_missing(self): @@ -52,12 +56,47 @@ def test_getitem_missing(self): def test_keys(self): d = self.dawg() - assert d.keys() == ['bar', 'foo', 'foo', 'foobar'] + assert d.keys() == ['bar', 'foo', 'foo', 'foobar', 'ሀ', 'ሮ', 'ቄ'] def test_iterkeys(self): d = self.dawg() assert list(d.iterkeys()) == d.keys() + def test_children(self): + d = self.dawg() + assert d.children('foob') == [('fooba', False)] + assert d.children('fooba') == [('foobar', True)] + assert d.children('fo') == [('foo', True)] + assert d.children('foo') == [('foob', False)] + + def test_iterchildren(self): + d = self.dawg() + assert list(d.iterchildren('foob')) == [('fooba', False)] + assert list(d.iterchildren('fooba')) == [('foobar', True)] + assert list(d.iterchildren('fo')) == [('foo', True)] + assert list(d.iterchildren('foo')) == [('foob', False)] + + def test_children_data(self): + d = self.dawg() + assert d.children_data('foob') == [('fooba', None)] + assert d.children_data('fooba') == [('foobar', b'data4')] + assert d.children_data('fo') == [('foo', b'data1'), ('foo', b'data3')] + assert d.children_data('foobar') == [] + assert d.children_data('foo') == [('foob', None)] + assert set(d.children_data('')) == set([('b', None), ('f', None), + ('ሀ', b'ethiopic_sign1'), + ('ሮ', b'ethiopic_sign2'), + ('ቄ', b'ethiopic_sign3')]) + + def test_iterchildren_data(self): + d = self.dawg() + assert list(d.iterchildren_data('foob')) == [('fooba', None)] + assert list(d.iterchildren_data('fooba')) == [('foobar', b'data4')] + assert list(d.iterchildren_data('fo')) == \ + [('foo', b'data1'), ('foo', b'data3')] + assert list(d.iterchildren_data('foobar')) == [] + assert list(d.iterchildren_data('foo')) == [('foob', None)] + def test_key_completion(self): d = self.dawg() assert d.keys('fo') == ['foo', 'foo', 'foobar'] @@ -65,6 +104,7 @@ def test_key_completion(self): def test_items(self): d = self.dawg() assert d.items() == sorted(self.DATA) + assert d.items('not a real key') == [] def test_iteritems(self): d = self.dawg() @@ -75,6 +115,8 @@ def test_iteritems(self): def test_items_completion(self): d = self.dawg() assert d.items('foob') == [('foobar', b'data4')] + assert d.items('foo') == [('foo', b'data1'), ('foo', b'data3'), + ('foobar', b'data4')] def test_prefixes(self): d = self.dawg() @@ -121,6 +163,19 @@ def test_record_items(self): d = self.dawg() assert d.items() == sorted(self.STRUCTURED_DATA) + def test_children_data(self): + d = self.dawg() + assert d.children_data('foob') == [('fooba', None)] + assert d.children_data('fooba') == [('foobar', (6, 3, 0))] + assert d.children_data('fo') == [('foo', (3, 2, 1)), ('foo', (3, 2, 256))] + + def test_iterchildren_data(self): + d = self.dawg() + assert list(d.iterchildren_data('foob')) == [('fooba', None)] + assert list(d.iterchildren_data('fooba')) == [('foobar', (6, 3, 0))] + assert list(d.iterchildren_data('fo')) == [('foo', (3, 2, 1)), + ('foo', (3, 2, 256))] + def test_record_keys(self): d = self.dawg() assert d.keys() == ['bar', 'foo', 'foo', 'foobar',]