diff --git a/redis/_compat.py b/redis/_compat.py index c7859b571b..38d767d435 100644 --- a/redis/_compat.py +++ b/redis/_compat.py @@ -12,6 +12,16 @@ except ImportError: from StringIO import StringIO as BytesIO + # special unicode handling for python2 to avoid UnicodeDecodeError + def safe_unicode(obj, *args): + """ return the unicode representation of obj """ + try: + return unicode(obj, *args) + except UnicodeDecodeError: + # obj is byte string + ascii_text = str(obj).encode('string_escape') + return unicode(ascii_text) + iteritems = lambda x: x.iteritems() iterkeys = lambda x: x.iterkeys() itervalues = lambda x: x.itervalues() @@ -48,6 +58,7 @@ xrange = range basestring = str unicode = str + safe_unicode = str bytes = bytes long = int diff --git a/redis/client.py b/redis/client.py index e4445713eb..602136f04c 100755 --- a/redis/client.py +++ b/redis/client.py @@ -6,7 +6,7 @@ import threading import time as mod_time from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, - itervalues, izip, long, nativestr, unicode) + itervalues, izip, long, nativestr, unicode, safe_unicode) from redis.connection import (ConnectionPool, UnixDomainSocketConnection, SSLConnection, Token) from redis.lock import Lock, LuaLock @@ -2532,9 +2532,9 @@ def raise_first_error(self, commands, response): raise r def annotate_exception(self, exception, number, command): - cmd = unicode(' ').join(imap(unicode, command)) + cmd = safe_unicode(' ').join(imap(safe_unicode, command)) msg = unicode('Command # %d (%s) of pipeline caused error: %s') % ( - number, cmd, unicode(exception.args[0])) + number, cmd, safe_unicode(exception.args[0])) exception.args = (msg,) + exception.args[1:] def parse_response(self, connection, command_name, **options): diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 4849c81b7c..2e6f549b94 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -10,6 +10,16 @@ value = tonumber(value) return value * ARGV[1]""" +msgpack_hello_script = """ +local message = cmsgpack.unpack(ARGV[1]) +local name = message['name'] +return "hello " .. name +""" +msgpack_hello_script_broken = """ +local message = cmsgpack.unpack(ARGV[1]) +local names = message['name'] +return "hello " .. name +""" class TestScripting(object): @pytest.fixture(autouse=True) @@ -80,3 +90,24 @@ def test_script_object_in_pipeline(self, r): assert r.script_exists(multiply.sha) == [False] # [SET worked, GET 'a', result of multiple script] assert pipe.execute() == [True, b('2'), 6] + + def test_eval_msgpack_pipeline_error_in_lua(self, r): + msgpack_hello = r.register_script(msgpack_hello_script) + assert not msgpack_hello.sha + + pipe = r.pipeline() + + # avoiding a dependency to msgpack, this is the output of msgpack.dumps({"name": "joe"}) + msgpack_message_1 = b'\x81\xa4name\xa3Joe' + + msgpack_hello(args=[msgpack_message_1], client = pipe) + + assert r.script_exists(msgpack_hello.sha) == [True] + assert pipe.execute()[0] == b'hello Joe' + + msgpack_hello_broken = r.register_script(msgpack_hello_script_broken) + + msgpack_hello_broken(args=[msgpack_message_1], client = pipe) + with pytest.raises(exceptions.ResponseError) as excinfo: + pipe.execute() + assert excinfo.type == exceptions.ResponseError \ No newline at end of file