Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions redis/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
except ImportError:
from StringIO import StringIO as BytesIO

# special unicode handling for python2 to avoid UnicodeDecodeError
def safe_unicode(obj, *args):
""" return the unicode representation of obj """
try:
return unicode(obj, *args)
except UnicodeDecodeError:
# obj is byte string
ascii_text = str(obj).encode('string_escape')
return unicode(ascii_text)

iteritems = lambda x: x.iteritems()
iterkeys = lambda x: x.iterkeys()
itervalues = lambda x: x.itervalues()
Expand Down Expand Up @@ -48,6 +58,7 @@
xrange = range
basestring = str
unicode = str
safe_unicode = str
bytes = bytes
long = int

Expand Down
7 changes: 4 additions & 3 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import threading
import time as mod_time
from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys,
itervalues, izip, long, nativestr, unicode)
itervalues, izip, long, nativestr, unicode,
safe_unicode)
from redis.connection import (ConnectionPool, UnixDomainSocketConnection,
SSLConnection, Token)
from redis.lock import Lock, LuaLock
Expand Down Expand Up @@ -2532,9 +2533,9 @@ def raise_first_error(self, commands, response):
raise r

def annotate_exception(self, exception, number, command):
cmd = unicode(' ').join(imap(unicode, command))
cmd = safe_unicode(' ').join(imap(safe_unicode, command))
msg = unicode('Command # %d (%s) of pipeline caused error: %s') % (
number, cmd, unicode(exception.args[0]))
number, cmd, safe_unicode(exception.args[0]))
exception.args = (msg,) + exception.args[1:]

def parse_response(self, connection, command_name, **options):
Expand Down
33 changes: 33 additions & 0 deletions tests/test_scripting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
value = tonumber(value)
return value * ARGV[1]"""

msgpack_hello_script = """
local message = cmsgpack.unpack(ARGV[1])
local name = message['name']
return "hello " .. name
"""
msgpack_hello_script_broken = """
local message = cmsgpack.unpack(ARGV[1])
local names = message['name']
return "hello " .. name
"""


class TestScripting(object):
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -80,3 +91,25 @@ def test_script_object_in_pipeline(self, r):
assert r.script_exists(multiply.sha) == [False]
# [SET worked, GET 'a', result of multiple script]
assert pipe.execute() == [True, b('2'), 6]

def test_eval_msgpack_pipeline_error_in_lua(self, r):
msgpack_hello = r.register_script(msgpack_hello_script)
assert not msgpack_hello.sha

pipe = r.pipeline()

# avoiding a dependency to msgpack, this is the output of
# msgpack.dumps({"name": "joe"})
msgpack_message_1 = b'\x81\xa4name\xa3Joe'

msgpack_hello(args=[msgpack_message_1], client=pipe)

assert r.script_exists(msgpack_hello.sha) == [True]
assert pipe.execute()[0] == b'hello Joe'

msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)

msgpack_hello_broken(args=[msgpack_message_1], client=pipe)
with pytest.raises(exceptions.ResponseError) as excinfo:
pipe.execute()
assert excinfo.type == exceptions.ResponseError