diff --git a/pwnlib/util/packing.py b/pwnlib/util/packing.py index 4c5188402..7e032409d 100644 --- a/pwnlib/util/packing.py +++ b/pwnlib/util/packing.py @@ -32,6 +32,8 @@ """ from __future__ import absolute_import +from collections import Counter +from io import BytesIO import struct import sys @@ -479,10 +481,37 @@ def make_unpacker(word_size = None, endianness = None, sign = None, **kwargs): else: return lambda number: unpack(number, word_size, endianness, sign) +class Labelled(object): + def __init__(self, label, contents): + self.label = label + self.contents = contents +class Label(object): + def __init__(self, begin = 1, size = 1): + self._begin = begin + self._size = size -def _flat(args, preprocessor, packer): - out = [] + def mark(self, contents): + return Labelled(self, contents) + + def begin(self, unit = 1): + return self._begin / unit + + def end(self, unit = 1): + return (self._begin + self._size + (unit - 1)) / unit + + def size(self, unit = 1): + return (self._size + (unit - 1)) / unit + + def assign(self, begin, size): + changed = self._begin != begin or self._size != size + self._begin = begin + self._size = size + return changed + +def _flat(args, preprocessor, packer, initial_offset): + out = BytesIO() + changed_labels = [] for arg in args: if not isinstance(arg, (list, tuple)): @@ -491,20 +520,31 @@ def _flat(args, preprocessor, packer): arg = arg_ if hasattr(arg, '__flat__'): - out.append(arg.__flat__()) + out.write(arg.__flat__()) elif isinstance(arg, (list, tuple)): - out.append(_flat(arg, preprocessor, packer)) + changed, contents = _flat(arg, preprocessor, packer, initial_offset + out.tell()) + changed_labels += changed + out.write(contents) elif isinstance(arg, str): - out.append(arg) + out.write(arg) elif isinstance(arg, unicode): - out.append(arg.encode('utf8')) + out.write(arg.encode('utf8')) elif isinstance(arg, (int, long)): - out.append(packer(arg)) + out.write(packer(arg)) elif isinstance(arg, bytearray): - out.append(str(arg)) + out.write(str(arg)) + elif isinstance(arg, Labelled): + position = initial_offset + out.tell() + + changed, contents = _flat(arg.contents, preprocessor, packer, position) + out.write(contents) + + if arg.label.assign(position, len(contents)): + changed_labels.append(arg.label) + changed_labels += changed else: raise ValueError("flat(): Flat does not support values of type %s" % type(arg)) - return ''.join(out) + return changed_labels, out.getvalue() @LocalContext def flat(*args, **kwargs): @@ -542,7 +582,21 @@ def flat(*args, **kwargs): if kwargs != {}: raise TypeError("flat() does not support argument %r" % kwargs.popitem()[0]) - return _flat(args, preprocessor, make_packer(word_size)) + MAX_ITERATIONS = 100000 + for _ in xrange(MAX_ITERATIONS): + called = [a() if callable(a) else a for a in args] + changed_labels, result = _flat(called, preprocessor, make_packer(word_size), 0) + if not changed_labels: + return result + + if not all(callable(x) for x in args): + raise ValueError("all arguments must be functions when using labels") + + duplicates = [l for l, v in Counter(changed_labels).iteritems() if v > 1] + if duplicates: + raise ValueError("some labels are assigned twice") + + raise RuntimeError("could not find a valid label assignment in %d iterations" % MAX_ITERATIONS) @LocalContext @@ -622,24 +676,48 @@ def fill(out, value): out += filler.next() return out, out.index(value) - # convert str keys to offsets - # convert large int keys to offsets - pieces_ = dict() - for k, v in pieces.items(): - if isinstance(k, (int, long)): - # cyclic() generally starts with 'aaaa' - if k >= 0x61616161: - out, k = fill(out, pack(k)) - elif isinstance(k, str): - out, k = fill(out, k) - else: - raise TypeError("fit(): offset must be of type int or str, but got '%s'" % type(k)) - pieces_[k] = v - pieces = pieces_ - # convert values to their flattened forms - for k,v in pieces.items(): - pieces[k] = _flat([v], preprocessor, packer) + MAX_ITERATIONS = 100000 + original_pieces = pieces + for _ in xrange(MAX_ITERATIONS): + pieces = original_pieces + if callable(pieces): + pieces = pieces() + + # convert str keys to offsets + # convert large int keys to offsets + pieces_ = dict() + for k, v in pieces.items(): + if isinstance(k, (int, long)): + # cyclic() generally starts with 'aaaa' + if k >= 0x61616161: + out, k = fill(out, pack(k)) + elif isinstance(k, str): + out, k = fill(out, k) + else: + raise TypeError("fit(): offset must be of type int or str, but got '%s'" % type(k)) + pieces_[k] = v + pieces = pieces_ + + # convert values to their flattened forms + changed_labels = [] + for k,v in pieces.items(): + new_changed, pieces[k] = _flat([v], preprocessor, packer, k) + changed_labels += new_changed + + # found a valid assignment + if not changed_labels: + break + + duplicates = [l for l, v in Counter(changed_labels).iteritems() if v > 1] + if duplicates: + raise ValueError("fit(): some labels are assigned twice") + + if not callable(original_pieces): + raise ValueError("fit(): argument must be a function when using labels") + + else: + raise RuntimeError("fit(): could not find a valid label assignment in %d iterations" % MAX_ITERATIONS) # if we were provided a length, make sure everything fits last = max(pieces) @@ -654,7 +732,6 @@ def fill(out, value): raise ValueError("fit(): data at offset %d overlaps with previous data which ends at offset %d" % (k, l)) while len(out) < k: out.append(filler.next()) - v = _flat([v], preprocessor, packer) l = k + len(v) # consume the filler for each byte of actual data