Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ability (and tests of) integer, float, tuple, and frozenset keys #74

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions sqlitedict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import random
import logging
import traceback
import json

from threading import Thread

Expand Down Expand Up @@ -213,7 +214,7 @@ def __bool__(self):
def iterkeys(self):
GET_KEYS = 'SELECT key FROM "%s" ORDER BY rowid' % self.tablename
for key in self.conn.select(GET_KEYS):
yield key[0]
yield _unconvertkey(key[0])
Copy link
Collaborator

@mpenkov mpenkov Jun 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use a more helpful function name, like serialize (and conversely, deserialize).


def itervalues(self):
GET_VALUES = 'SELECT value FROM "%s" ORDER BY rowid' % self.tablename
Expand All @@ -223,7 +224,7 @@ def itervalues(self):
def iteritems(self):
GET_ITEMS = 'SELECT key, value FROM "%s" ORDER BY rowid' % self.tablename
for key, value in self.conn.select(GET_ITEMS):
yield key, self.decode(value)
yield _unconvertkey(key), self.decode(value)

def keys(self):
return self.iterkeys() if major_version > 2 else list(self.iterkeys())
Expand All @@ -233,31 +234,36 @@ def values(self):

def items(self):
return self.iteritems() if major_version > 2 else list(self.iteritems())

def __contains__(self, key):
key = _convertkey(key)
HAS_ITEM = 'SELECT 1 FROM "%s" WHERE key = ?' % self.tablename
return self.conn.select_one(HAS_ITEM, (key,)) is not None

def __getitem__(self, key):
key = _convertkey(key)
GET_ITEM = 'SELECT value FROM "%s" WHERE key = ?' % self.tablename
item = self.conn.select_one(GET_ITEM, (key,))
if item is None:
raise KeyError(key)
return self.decode(item[0])

def __setitem__(self, key, value):
key = _convertkey(key)
if self.flag == 'r':
raise RuntimeError('Refusing to write to read-only SqliteDict')

ADD_ITEM = 'REPLACE INTO "%s" (key, value) VALUES (?,?)' % self.tablename
self.conn.execute(ADD_ITEM, (key, self.encode(value)))

def __delitem__(self, key):
key0 = key
key = _convertkey(key)
if self.flag == 'r':
raise RuntimeError('Refusing to delete from read-only SqliteDict')

if key not in self:
raise KeyError(key)
raise KeyError(key0)
DEL_ITEM = 'DELETE FROM "%s" WHERE key = ?' % self.tablename
self.conn.execute(DEL_ITEM, (key,))

Expand Down Expand Up @@ -356,6 +362,31 @@ def __del__(self):
# in __del__ method.
pass

def _convertkey(key):
if isinstance(key,int):
return '___.int.' + repr(key)
elif isinstance(key,float):
return '___.float.' + repr(key)
# These are recursive
elif isinstance(key,tuple):
return '___.tuple.' + json.dumps([_convertkey(k) for k in key])
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be very helpful to allow passing in a JSONEncoder to be used here as a cls=.. argument, to allow inclusion of types which need help serialising.

This could be done after your PR, if you like, but that may be a useful element in the design of this PR.

elif isinstance(key,frozenset):
return '___.frozenset.' + json.dumps(sorted(_convertkey(k) for k in key))
return key
def _unconvertkey(key):
if key.startswith('___.'):
_,keytype,newkey = key.split('.',2)
if keytype == 'int':
return int(newkey)
elif keytype == 'float':
return float(newkey)
elif keytype == 'tuple':
return tuple(_unconvertkey(k) for k in json.loads(newkey))
elif keytype == 'frozenset':
return frozenset(_unconvertkey(k) for k in json.loads(newkey))
# Otherwise do nothing and return
return key

# Adding extra methods for python 2 compatibility (at import time)
if major_version == 2:
SqliteDict.__nonzero__ = SqliteDict.__bool__
Expand Down
21 changes: 20 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,25 @@ def test_commit_nonblocking(self):
d['key'] = 'value'
d.commit(blocking=False)

def test_special_keys(self):
"""integer, float and/or tuple keys"""
db = SqliteDict()
db['1'] = 1
db[1] = 'ONE'
db[('a',1)] = 'testtuple'
db[frozenset([1,2,'2'])] = 'testfrozenset'
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should have assertions about the underlying generated key names, possibly as separate test method, but also here so it is possible to understand the implementation by reading the tests

assert db[1] == 'ONE'
assert db['1'] == 1
assert db[('a',1)] == 'testtuple'
assert db[frozenset([1,2,'2'])] == 'testfrozenset'

# This tests the reverse conversion
keys = list(db.keys())
assert len(keys) == 4
assert '1' in keys
assert 1 in keys
assert ('a',1) in keys
assert frozenset([1,2,'2']) in keys

class NamedSqliteDictCreateOrReuseTest(TempSqliteDictTest):
"""Verify default flag='c', and flag='n' of SqliteDict()."""
Expand Down Expand Up @@ -279,4 +298,4 @@ def test_tablenames(self):
self.assertEqual(SqliteDict.get_tablenames(fname), ['table1','table2'])

tablenames = SqliteDict.get_tablenames('tests/db/tablenames-test-2.sqlite')
self.assertEqual(tablenames, ['table1','table2'])
self.assertEqual(tablenames, ['table1','table2'])