From 331a61dda142ae620b1491e18575e82569350955 Mon Sep 17 00:00:00 2001 From: comaniac Date: Thu, 22 Aug 2019 13:08:39 -0700 Subject: [PATCH 1/2] [AutoTVM] Fix database APIs --- python/tvm/autotvm/database.py | 24 +++++++++++-------- .../python/unittest/test_autotvm_database.py | 12 ++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/python/tvm/autotvm/database.py b/python/tvm/autotvm/database.py index 9490cfe9cbc2..391b86d45817 100644 --- a/python/tvm/autotvm/database.py +++ b/python/tvm/autotvm/database.py @@ -122,7 +122,8 @@ def get(self, key): def load(self, inp, get_all=False): current = self.get(measure_str_key(inp)) if current is not None: - current = str(current) + if isinstance(current, bytes): + current = current.decode() records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)] results = [rec[1] for rec in records] if get_all: @@ -132,6 +133,8 @@ def load(self, inp, get_all=False): def save(self, inp, res, extend=False): current = self.get(measure_str_key(inp)) + if isinstance(current, bytes): + current = current.decode() if not extend or current is None: self.set(measure_str_key(inp), RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)])) @@ -142,29 +145,33 @@ def save(self, inp, res, extend=False): def filter(self, func): """ - Dump all of the records for a particular target + Dump all of the records that match the given rule Parameters ---------- func: callable - The signature of the function is bool (MeasureInput, Array of MeasureResult) + The signature of the function is (MeasureInput, [MeasureResult]) -> bool Returns ------- - list of records (inp, result) matching the target + list of records in tuple (MeasureInput, MeasureResult) matching the rule Examples -------- get records for a target >>> db.filter(lambda inp, resulst: "cuda" in inp.target.keys) + get records with errors + >>> db.filter(lambda inp, results: any(r.error_no != 0 for r in results)) """ matched_records = list() # may consider filtering in iterator in the future - for key in self.db: + for key in self.db.keys(): current = self.get(key) + if isinstance(current, bytes): + current = current.decode() try: - records = [decode(x) for x in current.spilt(RedisDatabase.MAGIC_SPLIT)] - except TypeError: # got a badly formatted/old format record + records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)] + except TypeError: # got a badly formatted/old format record continue inps, results = zip(*records) @@ -190,8 +197,5 @@ def __init__(self): def set(self, key, value): self.db[key] = value - def get(self, key): - return self.db.get(key) - def flush(self): self.db = {} diff --git a/tests/python/unittest/test_autotvm_database.py b/tests/python/unittest/test_autotvm_database.py index 4c530dc83f7a..3884444bb496 100644 --- a/tests/python/unittest/test_autotvm_database.py +++ b/tests/python/unittest/test_autotvm_database.py @@ -99,8 +99,20 @@ def test_db_latest_all(): assert encode(inp1, load4[1]) == encode(inp1, res2) assert encode(inp1, load4[2]) == encode(inp1, res3) +def test_db_filter(): + logging.info("test db filter ...") + records = get_sample_records(5) + _db = database.DummyDatabase() + _db.flush() + for inp, result in records: + _db.save(inp, result) + + records = _db.filter(lambda inp, ress: any(r.costs[0] <= 2 for r in ress)) + assert len(records) == 2 + if __name__ == '__main__': logging.basicConfig(level=logging.INFO) test_save_load() test_db_hash() test_db_latest_all() + test_db_filter() From 7abc4c90c23dd95a3b04d3075e52cc3b6434813f Mon Sep 17 00:00:00 2001 From: comaniac Date: Tue, 27 Aug 2019 16:22:15 -0700 Subject: [PATCH 2/2] Refactor the byte conversion --- python/tvm/autotvm/database.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/tvm/autotvm/database.py b/python/tvm/autotvm/database.py index 391b86d45817..f820c1234832 100644 --- a/python/tvm/autotvm/database.py +++ b/python/tvm/autotvm/database.py @@ -117,13 +117,12 @@ def set(self, key, value): self.db.set(key, value) def get(self, key): - return self.db.get(key) + current = self.db.get(key) + return current.decode() if isinstance(current, bytes) else current def load(self, inp, get_all=False): current = self.get(measure_str_key(inp)) if current is not None: - if isinstance(current, bytes): - current = current.decode() records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)] results = [rec[1] for rec in records] if get_all: @@ -133,8 +132,6 @@ def load(self, inp, get_all=False): def save(self, inp, res, extend=False): current = self.get(measure_str_key(inp)) - if isinstance(current, bytes): - current = current.decode() if not extend or current is None: self.set(measure_str_key(inp), RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)])) @@ -167,8 +164,6 @@ def filter(self, func): # may consider filtering in iterator in the future for key in self.db.keys(): current = self.get(key) - if isinstance(current, bytes): - current = current.decode() try: records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)] except TypeError: # got a badly formatted/old format record