diff --git a/sqlitedict.py b/sqlitedict.py index 7b60235..0df6191 100755 --- a/sqlitedict.py +++ b/sqlitedict.py @@ -33,6 +33,7 @@ import random import logging import traceback +import json from threading import Thread @@ -213,7 +214,7 @@ def __bool__(self): def iterkeys(self): GET_KEYS = 'SELECT key FROM "%s" ORDER BY rowid' % self.tablename for key in self.conn.select(GET_KEYS): - yield key[0] + yield _unconvertkey(key[0]) def itervalues(self): GET_VALUES = 'SELECT value FROM "%s" ORDER BY rowid' % self.tablename @@ -223,7 +224,7 @@ def itervalues(self): def iteritems(self): GET_ITEMS = 'SELECT key, value FROM "%s" ORDER BY rowid' % self.tablename for key, value in self.conn.select(GET_ITEMS): - yield key, self.decode(value) + yield _unconvertkey(key), self.decode(value) def keys(self): return self.iterkeys() if major_version > 2 else list(self.iterkeys()) @@ -233,12 +234,14 @@ def values(self): def items(self): return self.iteritems() if major_version > 2 else list(self.iteritems()) - + def __contains__(self, key): + key = _convertkey(key) HAS_ITEM = 'SELECT 1 FROM "%s" WHERE key = ?' % self.tablename return self.conn.select_one(HAS_ITEM, (key,)) is not None def __getitem__(self, key): + key = _convertkey(key) GET_ITEM = 'SELECT value FROM "%s" WHERE key = ?' % self.tablename item = self.conn.select_one(GET_ITEM, (key,)) if item is None: @@ -246,6 +249,7 @@ def __getitem__(self, key): return self.decode(item[0]) def __setitem__(self, key, value): + key = _convertkey(key) if self.flag == 'r': raise RuntimeError('Refusing to write to read-only SqliteDict') @@ -253,11 +257,13 @@ def __setitem__(self, key, value): self.conn.execute(ADD_ITEM, (key, self.encode(value))) def __delitem__(self, key): + key0 = key + key = _convertkey(key) if self.flag == 'r': raise RuntimeError('Refusing to delete from read-only SqliteDict') if key not in self: - raise KeyError(key) + raise KeyError(key0) DEL_ITEM = 'DELETE FROM "%s" WHERE key = ?' % self.tablename self.conn.execute(DEL_ITEM, (key,)) @@ -356,6 +362,31 @@ def __del__(self): # in __del__ method. pass +def _convertkey(key): + if isinstance(key,int): + return '___.int.' + repr(key) + elif isinstance(key,float): + return '___.float.' + repr(key) + # These are recursive + elif isinstance(key,tuple): + return '___.tuple.' + json.dumps([_convertkey(k) for k in key]) + elif isinstance(key,frozenset): + return '___.frozenset.' + json.dumps(sorted(_convertkey(k) for k in key)) + return key +def _unconvertkey(key): + if key.startswith('___.'): + _,keytype,newkey = key.split('.',2) + if keytype == 'int': + return int(newkey) + elif keytype == 'float': + return float(newkey) + elif keytype == 'tuple': + return tuple(_unconvertkey(k) for k in json.loads(newkey)) + elif keytype == 'frozenset': + return frozenset(_unconvertkey(k) for k in json.loads(newkey)) + # Otherwise do nothing and return + return key + # Adding extra methods for python 2 compatibility (at import time) if major_version == 2: SqliteDict.__nonzero__ = SqliteDict.__bool__ diff --git a/tests/test_core.py b/tests/test_core.py index 1e43c4f..8d8e06c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -67,6 +67,25 @@ def test_commit_nonblocking(self): d['key'] = 'value' d.commit(blocking=False) + def test_special_keys(self): + """integer, float and/or tuple keys""" + db = SqliteDict() + db['1'] = 1 + db[1] = 'ONE' + db[('a',1)] = 'testtuple' + db[frozenset([1,2,'2'])] = 'testfrozenset' + assert db[1] == 'ONE' + assert db['1'] == 1 + assert db[('a',1)] == 'testtuple' + assert db[frozenset([1,2,'2'])] == 'testfrozenset' + + # This tests the reverse conversion + keys = list(db.keys()) + assert len(keys) == 4 + assert '1' in keys + assert 1 in keys + assert ('a',1) in keys + assert frozenset([1,2,'2']) in keys class NamedSqliteDictCreateOrReuseTest(TempSqliteDictTest): """Verify default flag='c', and flag='n' of SqliteDict().""" @@ -279,4 +298,4 @@ def test_tablenames(self): self.assertEqual(SqliteDict.get_tablenames(fname), ['table1','table2']) tablenames = SqliteDict.get_tablenames('tests/db/tablenames-test-2.sqlite') - self.assertEqual(tablenames, ['table1','table2']) \ No newline at end of file + self.assertEqual(tablenames, ['table1','table2'])